{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Attend Infer Repeat\n",
    "\n",
    "In this tutorial we will implement the model and inference strategy described in \"Attend, Infer, Repeat:\n",
    "Fast Scene Understanding with Generative Models\" (AIR) [1] and apply it to the multi-mnist dataset.\n",
    "\n",
    "A [standalone implementation](https://github.com/uber/pyro/tree/dev/examples/air) is also available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "%pylab inline\n",
    "import os\n",
    "from collections import namedtuple\n",
    "from observations import multi_mnist\n",
    "import pyro\n",
    "import pyro.optim as optim\n",
    "from pyro.infer import SVI, TraceGraph_ELBO\n",
    "import pyro.distributions as dist\n",
    "import pyro.poutine as poutine\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn.functional import relu, sigmoid, softplus, grid_sample, affine_grid\n",
    "import numpy as np\n",
    "\n",
    "smoke_test = ('CI' in os.environ)\n",
    "assert pyro.__version__.startswith('0.3.0')\n",
    "pyro.enable_validation(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "The model described in [1] is a generative model of scenes. In this tutorial we will use it to model images from a dataset that is similar to the multi-mnist dataset in [1]. Here are some data points from this dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "keep_output": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAABvCAYAAADfcqgvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAFEBJREFUeJzt3X10z+Ufx/Hn1+ZmTYxtNoStECtsKtF9utUpKqEbpcahU2ed6oRyDkkqR1mpIznRoqJQ6iRJOW5S6hRRqrmpjSG3KTOWbPv98fld76EpZd/v5/vdXo9/Vttsl8tne3+v9/W+3legrKwMERER8U8NvwcgIiJS3SkYi4iI+EzBWERExGcKxiIiIj5TMBYREfGZgrGIiIjPFIxFRER8pmAsIiLiMwVjERERn0WH8psFAgG1+/qXysrKAv/1z2q+/70TmW/QnP8XesZDS/MdWsc731oZi4iI+EzBWERExGcKxiIiIj5TMBYREfGZgrGIiIjPFIxFRER8pmAsIiLis5CeMxaRypWens6mTZsA+PXXX30ejUj4O/XUUwG4+OKLmTNnDgCHDh0iLS0NgE6dOgFQs2ZN9u/fD8Dq1avZsWMHABs3bqSkpKTSx6VgLBKBatWqBcDAgQN5/vnngSODcb169bj//vsBmDt3LitXrgz9IEXCUN26dQFITk4mJSUFgDvvvJMrrrgCgKZNmwLw559/sm/fPgDy8/M5cOAAALNmzSInJ6fSx6U0tYiIiM+0MhaJQC7VlpaWRmlp6V8+3rlzZ7p16wbA5MmTQzo2kXD2ww8/APDbb7/x3HPPAdClSxe++eYbACZNmgTAypUr2bx5MwAZGRkkJycDUFxcHJRxKRiLRKBzzz0XgKioKLZt2/aXj7dt25aTTjoJgN9//z2kYxMJZ4cOHQK8n5HzzjsPgEWLFnHPPfcAsGfPHvvc888/H4Ds7GyysrIAb9snGJSmFhER8ZlWxiIRqF27dgBER0ezd+9ee78r7GrWrBk1a9YEoKxMF+2I/J06depw8sknA1jRVkZGBuPGjQOgsLCQrVu3BnUMER2Mo6KiAIiNjSU62vurFBcXW06/or00karABdoaNY5MbsXExADQokULS18fOnTIUtYJCQmAt192eBAXqW7y8/NZvHgxAFdddRWPP/44AB988AEAWVlZpKamApCZmRn0EwlKU4uIiPgsIlfGgYB3V7NL1T3yyCM0btwYgK+++spe7Xz33XeWWjh48GDoByoSJH/++Sfg/Sy4rFBpaamtjGvWrGnVnyNHjqRt27YA9rnZ2dksWrQo1MMWCRtr167l4YcfBryfi969ewNw6aWXApCYmMjo0aMBLKYEU0QHY/eL5eeff7YuRCkpKTz77LMAbNiwwSrf5s6dy8aNGwHtoUnkW7FiBQDXXXcdffr0AbwtmgsuuADwqq1dc4MzzzzTUmy5ubkArFu3LtRDFgk7Libk5ORw9tlnAxAfHw94+8gbNmwAoKioKOhjUZpaRETEZ4FQrhIDgUClfrPDV8juvxs2bMjpp58OwNVXX22ND7Zt28bUqVMBmDdvXsScvSwrKwv81z9b2fPtWseddtpptG/fHvAO0K9atQqA7du3V+a388WJzDdU/pwfy2mnnQZ4z3Lz5s0B2Lt3L1u2bAEgNTWVL7/8EoCbb77ZnvdwzAqF0zNeHWi+y7kK6kcffZQmTZoA3vYmwKBBg+znqUePHuzates/fY/jnW+tjEVERHwWkXvGjnuV74pZwFsBuxXaypUrWbJkCeA11B85ciQA55xzjrVBKygoCMvVQqhFR0cTFxcHeEfG3Jy2b9+eQYMGAViZf1JSEklJSYC3Gv7xxx8BePHFF/noo4+AI/9NpPLl5eUB0KtXL+rVqwfA/v377Zzx+PHjWbhwIeAdYxKRv7r22msBuOaaa6wD1xdffAF4nbjcZStnnXWW/W4LlogOxsfigmthYSELFiwAYP369fTq1QuAfv36WfX1I488Uq0Lu9w51bPPPpunn34a8KoICwoKAK8HsrtGzJ1bLSgosNR0amoqiYmJ
gHfziUvxuDmV4HBn6L/99tsj3l+/fn0ANm3aFDFbMSJ+SElJsQCcm5vL999/D5T3nn7ppZfIzMwEoGPHjkEPxkpTi4iI+KxKroydQCBgK7+NGzcyYcIEwGt3Nnz4cAD69+/PmDFjgNCUr4ebli1bAjBixAg7FgNYGnrs2LGWXVi7di1Q3i7uaHFxcdVyDsOJ68xVu3ZtdaAT+Rvt27e3LZ7hw4ezc+fOIz6ekpJihcEuUxhMVTIYu32ztLQ0MjIyAK/phzubOXv2bNv/7Nmzp+0rf/LJJz6M1l/9+vUD4MILL+T1118HYMKECbbPuGnTJktT/xPtTfrPNfpISUmx699E5K+6dOliTaGWLVtmfSvcAqVfv362demuXQwmpalFRER8ViVXxp06dQK8FKvrpnLyySeTn58PeCu/iRMnAt7NHHfddRdQPVfGp5xyCuAVr7lq6hYtWlinpuNdFUt4cMV0TZs2DfotMyKRrHbt2nbZUGZmprVXdvGjUaNGPProowBBvyQCqmAwrlWrFldddRXgBZIhQ4YA3m027ojOkCFDGDp0KOBVWbvKanezTXUKQG+++Sbg/d27du0KeGX87rjS4sWLrWL3008/Bbx0tPYjw5M7UrZmzRrdyiTyN6ZNm0bPnj0BuP32262aetasWQB8/vnnfP311yEbj9LUIiIiPovodpgVadCgAdnZ2YB3z/GAAQMArwK4devWALz88su28isqKrIiL1fMFE5nZIPdus5V38bGxlqj9AEDBthcNWjQwCoOXapmwoQJfPzxx0DVq0CPlHaYVYnaM4aW5ju0jne+q1yauqioiJ9++gnAelSD1yTB3ez02WefccsttwDwyy+/sHv3bsDruFLduLTmb7/9ZnvmS5Yssd7H7du3p3///kD5XkpOTo7tpUyZMqXKBWQRkVBTmlpERMRnVW5lfPDgQTug3alTJ2sPuHfvXlsFbty40aqIGzZsSE5Ojn2OeC0yD7/3dunSpQC0adMG8KrUR40aBXjZBFcEpn7UkcU1xImLi7Ptij179nDw4EH7eExMDAAlJSXWJlBEKl+VC8aAVcUNGDCAvn37AjBjxgxq164NQLNmzewXUUFBAdOmTfNnoGGqa9eudixm9erV1pPavX3ggQesQUh2djY///wz4KX/JXK4BiGTJk2iVatWAEyePJnFixcD3iXrV199NQCrVq2yn5Pq2MNdJNiUphYREfFZlVwZr1mzBoD33nuPu+++G4DLL7/cDninpaVZKm7atGn2+eIpLi5mxIgRgHer1fr16wFs/goLC60/dcOGDa3YSyvjyBEVFcU111wDeDd2uXZ/w4YNIysrC/Aa5TRo0ACAcePG+TNQkWqiSgZjt7eVk5Nj/XnT0tJsT/OLL76wAPLhhx9qL+woa9eutT6tI0eO5J133gG8X84At912m+0f5+fnW29viRz16tWjW7dugLdPfO+99wLeFo47aXDrrbfaz8/cuXOVnhYJIqWpRUREfFYlV8bO7t27mT17NuA1t3Cv7AOBgN3spDOyf7V161YGDx4MwPTp06093OFcY5Ts7Gy2b98e0vHJievcuTOXXnop4F2i7s7g5+bmWlvY66+/nilTpgDw5Zdf+jNQkf87/Pd2cnKyFeTGxMTQsGFDAMvo7d+/39r4FhYW+jDaf08rYxEREZ9VuXaYVY3frevatGnDbbfdBmBntj/77DO7G7qgoIA//vjjRL9N2Kjq7TCbN28OwKuvvsqpp54KQN++ffn8888Bry7AHVurX78+AwcOBLw6gmD9rvD7Ga9uImG+3ao3OTnZ6lMSExPtCF6nTp2sTW9cXJzdT+/+XH5+Pn369AEI6WUPFam27TClcuXm5jJ8+PAT+ho1a9bksssuA7zir9jYWKC80G7mzJm8++67JzZQOS6ucDE+Pp7x48cD3lly92/y0EMPWcC+//77WbduHaCzxRJ8gYAXs5o3b24V/eeffz5JSUmAd7OcWxC4JjUAf/zx
h/VF+OabbwDvtrm8vLyQjb0yKE0tIiLiM62Mj8EVCjRp0oSUlBTAezW2c+dOwFsx6kjU8alVqxYdOnQAoHv37tStWxcob5+ZkJBg8z1v3jw7wyyVb/Xq1QAMHjyY5cuXA14Ro7sDfODAgTzzzDMALF26VPdWS8jUqVMHgKFDh1rnxJiYGDtel5eXZwW36enpxMXFAfDOO+8wZswYALZs2QLAgQMHrJdEpFAwrkBUVJRdJzhq1Ci6dOkCeGkSl7YbOXKkna/duXMniYmJAOzYsQOAQ4cOhXrYYauoqIiFCxcCcMYZZ9CyZUsA2rVrB3jtN5s2bQp4t0ctWLDAn4FWA7/++isACxYssJaw6enpPPHEE/Z+16u9pKTEn0FKteRekKelpdne79dff83YsWMBWLFiBQcOHAC8AJyeng54lf6uBXIkU5paRETEZ1oZV6Bu3brWRrNr1672/tLSUlvVPfXUU3z11VcAbN682W6BckUDixYtYtmyZYC3Mqzu6T5X0XjHHXdYdeSMGTMA785kl15y510luKKiosjIyABgzJgxtnUwadIkWz2LhJIrEjxw4ICdDX777beZP38+4BV4uer+1q1b2wU1K1eu9GG0lU/BuALp6elccskl9v9u72H9+vWWjm7RogVNmjQBvCDtUn4uPX3HHXcwb948AIYPH27BRrCmEu6AfmlpKT/99BNAxFVARqpzzjmHiRMnAtCoUSMeeughAJYvX17tXziKP9x+8IwZM+jYsSPgtWR1KeiioiL69+8PePvL7gjeqlWrKvx6rjo7NjaW/fv3A4T1s600tYiIiM+0Mq5At27dbAUM5bcRvfzyy9ZQPykpyVZ2h3Pvq1OnDt27dwdgypQpWhn/X1xcnM1L48aNAe/ijjfeeAMor7CW4HBZiV69etG6dWvAS1O///77gIq2xD/u2Zs/fz4XXXQRAH369LHq/q1bt9oW17p165gzZw6ArXqP5s7OX3vttfZ8h3P7YwVjsBSz20Pr0qWLBdWioiJLh3z77bdHpDlc56k5c+bwyy+/AF56GrymCuIJBAJ2FV9mZiaZmZlAecpo+vTpdswmnNNIVYE7wnTTTTcxc+ZMACZOnKjjZBI2du3aZQG4Ro0a3HjjjQCcfvrpRwTsbdu2/e3Xcc1rLr/8ctsyDGdKU4uIiPhMK2PK7+m98847AW9lHBUVBcCyZctYtGgRAPv27bMVXIcOHazyd+zYsbYydv1+e/ToEbLxh7vExEQGDRoEQFZWlrWye+WVVwCvYEPp0eCpUaMG5513HoA1R8jLy+Oxxx4DvJWISLgoLS21fg6jRo2ibdu2gFd06Cquk5OTrRnTmjVrKvw6rimIexvuFIwpr7o7usoXvFS0q6betWsXzz77rH2u+0fetm2bpaUbNWoUsnGHu5iYGAD69+9ve+3x8fF89NFHALzwwgsA/P777/4MsJqoVasWV155JVB+2cd9991Hfn6+j6MSOTa3XZWQkGDH7g6/QvGGG26wxdCIESMqrDVxX6OwsDAieqsrTS0iIuIzrYwP41bI7u3R/w2wfft2AEaPHm0r6R07dljLTFftFwgErI91OFfwBZO7evHuu+8mISEB8M5qv/baawCW2o+EV62R7ODBg0yfPh3AtlzcdotIOHJZtR49ethNY5s3b2bt2rUAdO7cmQsuuADwTmVU1CzI9S4YN27cMSuuw4lWxiIiIj7TyvgwboV2+ErtWKs2dyEEeHvMrhuX29MoLi5m9uzZAGzYsCEo4w1nKSkpdiQhKSmJH3/8EYAnn3yS9957Dyg/GibBVVpaSm5uLoC9FQlnLpOWnp5ul0Zs3brVjiidccYZJCcnA9CqVasKV8YuI7lnz56IyL4pGP+DVq1aWTXf9u3bj6j6dYVeXbt2tbOz7sGZN2+eNbKIhBRJZcvIyLB5Ky4u5q233gLg/fffr5bzISLHr6KFUWpqql2t2KBBA3sx706+HM3dI3DLLbcwevRoILy3DJWm
FhER8ZlWxpS3YFy/fj3gpTVcx6g2bdrYq6qhQ4daU/KoqCg7uzlq1Cg6dOgAlF8qsWTJEis2iIQUSWUrLCy0FfApp5zChRdeCHhnitXtSUT+jit+/fTTT61QKyEhwdLXRUVFrF69GsDOJB9tz549gJfejoTfwYFQDjIQCITljLiKabcH8eCDD9oVirGxsXYT0w8//GABNjo62nr7tm3b1tLXS5YsAWDw4MH2sJyIsrKywD9/VsX8nO/ExEQGDBgAwLBhwywAz5o1y1LW7lrFcNo7PpH5hvB9xsNZpD7jkSqS5js+Pp6srCzAa/rhzg4vXLjQelMXFBRU2EbXtTmuXbs2xcXFgD8Lo+Odb6WpRUREfKaVcQWaNm1qxVcdO3a0DjCHCwQC9iqrpKTEzm8OGzYM8C68roxLDyLpVezRWrRoAcDUqVNp1qwZ4LWm27x5M4BVQJaUlLB3714AJk+ezNKlS30YrUcr49CL5Gc8Emm+Q+t451vBuALR0dG2B9y3b1969+4NeGlsl9Let2+fNQD55JNPyMnJAWDFihVA5V1FF8k/OC5NFB8fT2pqKgA9e/a0KxRda8aysjJ74TJkyBBmzJjhw2hxY1EwDrFIfsYjkeY7tJSmFhERiRBaGR+DWwE3atTI7jkeNGiQtcB844037LaQvLw8u+ygsu/jrWqvYuvUqWM3Wu3evRvwsghXXHEF4J3PXrZsmW/j08o49KraMx7uNN+hpTR1FaEfnNBSMA49PeOhpfkOLaWpRUREIoSCsYiIiM8UjEVERHymYCwiIuIzBWMRERGfKRiLiIj4TMFYRETEZyE9ZywiIiJ/pZWxiIiIzxSMRUREfKZgLCIi4jMFYxEREZ8pGIuIiPhMwVhERMRnCsYiIiI+UzAWERHxmYKxiIiIzxSMRUREfKZgLCIi4jMFYxEREZ8pGIuIiPhMwVhERMRnCsYiIiI+UzAWERHxmYKxiIiIzxSMRUREfKZgLCIi4jMFYxEREZ8pGIuIiPhMwVhERMRnCsYiIiI++x9GRfeHJefNbwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x1055f1f60>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "inpath = '../../examples/air/data'\n",
    "(X_np, _), _ = multi_mnist(inpath, max_digits=2, canvas_size=50, seed=42)\n",
    "X_np = X_np.astype(np.float32)\n",
    "X_np /= 255.0\n",
    "mnist = torch.from_numpy(X_np)\n",
    "def show_images(imgs):\n",
    "    figure(figsize=(8, 2))\n",
    "    for i, img in enumerate(imgs):\n",
    "        subplot(1, len(imgs), i + 1)\n",
    "        axis('off')\n",
    "        imshow(img.data.numpy(), cmap='gray')\n",
    "show_images(mnist[9:14])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get an idea of where we're heading, we first give a brief overview of the model and the approach we'll take to inference. We'll follow the naming conventions used in [1] as closely as possible.\n",
    "\n",
    "AIR decomposes the process of generating an image into discrete steps, each of which generates only part of the image. More specifically, at each step the model will generate a small image (`y_att`) by passing a latent \"code\" variable (`z_what`) through a neural network. We'll refer to these small images as \"objects\". In the case of AIR applied to the multi-mnist dataset, we expect each of these objects to represent a single digit. The model also includes uncertainty about the location and size of each object. We'll describe an object's location and size as its \"pose\" (`z_where`). To produce the final image, each object will first be located within a larger image (`y`) using the pose information `z_where`. Finally, the `y`s from all time steps will be combined additively to produce the final image `x`.\n",
    "\n",
    "Here's a picture (reproduced from [1]) that shows two steps of this process:"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/html"
   },
   "source": [
    "<center>\n",
    "<figure style='padding: 0 0 1em'>\n",
    "<img src='_static/img/model-generative.png' style='width: 35%;'>\n",
    "<figcaption style='font-size: 90%; padding: 0.5em 0 0'>\n",
    "<b>Figure 1:</b> Two steps of the generative process.\n",
    "</figcaption>\n",
    "</figure>\n",
    "</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inference is performed in this model using [amortized stochastic variational inference](svi_part_i.ipynb) (SVI). The parameters of the neural network are also optimized during inference. Performing inference in such rich models is always difficult, but the presence of discrete choices (here, the number of steps) makes inference in this model particularly tricky. For this reason the authors use a technique called data-dependent baselines to achieve good performance. This technique can be implemented in Pyro, and we'll see how later in the tutorial.\n",
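    "\n",
    "To give some intuition for why baselines help with discrete choices, here is a toy sketch in plain Python. It is not the AIR baseline: it assumes a single Bernoulli choice `z` with success probability `p`, a made-up reward `f(z) = z + offset`, and a *constant* baseline, whereas the data-dependent baselines used in [1] are, roughly speaking, learned functions of the observed image. The helper `score_function_grads` is hypothetical. It forms the score-function (REINFORCE) estimate `(f(z) - baseline) * (z - p)` of the gradient of `E[f(z)]` with respect to the logit of `p`, whose exact value is `p * (1 - p)`.\n",
    "\n",
    "```python\n",
    "import random\n",
    "import statistics\n",
    "\n",
    "\n",
    "def score_function_grads(p, offset, baseline, num_samples, seed=0):\n",
    "    # Toy setting (an assumption, not the AIR model): z ~ Bernoulli(p) and\n",
    "    # reward f(z) = z + offset. The exact gradient of E[f(z)] with respect\n",
    "    # to the logit of p is p * (1 - p).\n",
    "    rng = random.Random(seed)\n",
    "    grads = []\n",
    "    for _ in range(num_samples):\n",
    "        z = 1.0 if rng.random() < p else 0.0\n",
    "        # Score of a Bernoulli sample w.r.t. its logit: d/dlogit log q(z) = z - p.\n",
    "        score = z - p\n",
    "        grads.append((z + offset - baseline) * score)\n",
    "    return statistics.mean(grads), statistics.pvariance(grads)\n",
    "\n",
    "\n",
    "p, offset = 0.5, 10.0\n",
    "mean_plain, var_plain = score_function_grads(p, offset, baseline=0.0, num_samples=20000)\n",
    "mean_base, var_base = score_function_grads(p, offset, baseline=p + offset, num_samples=20000)\n",
    "print(mean_plain, var_plain)  # mean close to 0.25, variance close to 27.6\n",
    "print(mean_base, var_base)    # exactly 0.25 and 0.0 in this toy case\n",
    "```\n",
    "\n",
    "Subtracting a constant leaves the estimator unbiased, because the score `z - p` has zero mean, but it can reduce the variance dramatically. In this contrived case the baseline `p + offset` happens to remove the variance entirely.\n",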
    "\n",
    "## Model\n",
    "\n",
    "### Generating a single object\n",
    "\n",
    "Let's look at the model more closely. At the core of the model is the generative process for a single object. Recall that:\n",
    "\n",
    "* At each step a single object is generated.\n",
    "* Each object is generated by passing its latent code through a neural network.\n",
    "* We maintain uncertainty about the latent code used to generate each object, as well as its pose.\n",
    "\n",
    "This can be expressed in Pyro like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the neural network. This takes a latent code, z_what, to pixel intensities.\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.l1 = nn.Linear(50, 200)\n",
    "        self.l2 = nn.Linear(200, 400)\n",
    "\n",
    "    def forward(self, z_what):\n",
    "        h = relu(self.l1(z_what))\n",
    "        return sigmoid(self.l2(h))\n",
    "\n",
    "decode = Decoder()\n",
    "\n",
    "z_where_prior_loc = torch.tensor([3., 0., 0.])\n",
    "z_where_prior_scale = torch.tensor([0.1, 1., 1.])\n",
    "z_what_prior_loc = torch.zeros(50)\n",
    "z_what_prior_scale = torch.ones(50)\n",
    "\n",
    "def prior_step_sketch(t):\n",
    "    # Sample object pose. This is a 3-dimensional vector (s, x, y) representing size (scale) and x,y position.\n",
    "    z_where = pyro.sample('z_where_{}'.format(t),\n",
    "                          dist.Normal(z_where_prior_loc.expand(1, -1),\n",
    "                                      z_where_prior_scale.expand(1, -1))\n",
    "                              .to_event(1))\n",
    "\n",
    "    # Sample object code. This is a 50-dimensional vector.\n",
    "    z_what = pyro.sample('z_what_{}'.format(t),\n",
    "                         dist.Normal(z_what_prior_loc.expand(1, -1),\n",
    "                                     z_what_prior_scale.expand(1, -1))\n",
    "                             .to_event(1))\n",
    "    \n",
    "    # Map code to pixel space using the neural network.\n",
    "    y_att = decode(z_what)\n",
    "\n",
    "    # Position/scale object within larger image.\n",
    "    y = object_to_image(z_where, y_att)\n",
    "\n",
    "    return y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hopefully the use of `pyro.sample` and PyTorch networks within a model seems familiar at this point. If not, you might want to review the [VAE tutorial](vae.ipynb). One thing to note is that we include the current step `t` in the name passed to `pyro.sample` to ensure that names are unique across steps.\n",
    "\n",
    "The `object_to_image` function is specific to this model and warrants further attention. Recall that the neural network (`decode` here) will output a small image, and that we would like to add this to the output image after performing any translation and scaling required to achieve the pose (location and size) described by `z_where`. It's not immediately clear how to do this, and in particular it's not obvious that this can be implemented in a way that preserves the differentiability of our model, which we require in order to perform [SVI](svi_part_i.ipynb). However, it turns out we can do this using a spatial transformer network (STN) [2].\n",
    "\n",
    "Happily for us, PyTorch makes it easy to implement an STN using its [grid_sample](http://pytorch.org/docs/master/nn.html#grid-sample) and [affine_grid](http://pytorch.org/docs/master/nn.html#affine-grid) functions. `object_to_image` is a simple function that calls these, doing a little extra work to massage `z_where` into the expected format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def expand_z_where(z_where):\n",
    "    # Takes 3-dimensional vectors, and massages them into 2x3 matrices with elements like so:\n",
    "    # [s,x,y] -> [[s,0,x],\n",
    "    #             [0,s,y]]\n",
    "    n = z_where.size(0)\n",
    "    expansion_indices = torch.LongTensor([1, 0, 2, 0, 1, 3])\n",
    "    out = torch.cat((torch.zeros([1, 1]).expand(n, 1), z_where), 1)\n",
    "    return torch.index_select(out, 1, expansion_indices).view(n, 2, 3)\n",
    "\n",
    "def object_to_image(z_where, obj):\n",
    "    n = obj.size(0)\n",
    "    theta = expand_z_where(z_where)\n",
    "    grid = affine_grid(theta, torch.Size((n, 1, 50, 50)))\n",
    "    out = grid_sample(obj.view(n, 1, 20, 20), grid)\n",
    "    return out.view(n, 50, 50)"
   ]
  },
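  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check that this placement really is differentiable, the sketch below (an illustration added here, not part of the original implementation) copies `expand_z_where` and `object_to_image` from the cell above, places a random 20x20 object at an arbitrary pose, and backpropagates a scalar through `affine_grid` and `grid_sample`. The pose values and the choice of `y.sum()` as the scalar are arbitrary. Depending on your PyTorch version, these functions may warn about the `align_corners` argument; the check does not depend on it.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.nn.functional import affine_grid, grid_sample\n",
    "\n",
    "\n",
    "def expand_z_where(z_where):\n",
    "    # [s, x, y] -> [[s, 0, x], [0, s, y]], as in the cell above.\n",
    "    n = z_where.size(0)\n",
    "    expansion_indices = torch.LongTensor([1, 0, 2, 0, 1, 3])\n",
    "    out = torch.cat((torch.zeros([1, 1]).expand(n, 1), z_where), 1)\n",
    "    return torch.index_select(out, 1, expansion_indices).view(n, 2, 3)\n",
    "\n",
    "\n",
    "def object_to_image(z_where, obj):\n",
    "    # Place a batch of 20x20 objects on 50x50 canvases, as in the cell above.\n",
    "    n = obj.size(0)\n",
    "    theta = expand_z_where(z_where)\n",
    "    grid = affine_grid(theta, torch.Size((n, 1, 50, 50)))\n",
    "    out = grid_sample(obj.view(n, 1, 20, 20), grid)\n",
    "    return out.view(n, 50, 50)\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# A random object (flattened 20x20 pixels) and a pose, both tracked by autograd.\n",
    "obj = torch.rand(1, 400, requires_grad=True)\n",
    "z_where = torch.tensor([[3.0, 0.2, -0.1]], requires_grad=True)\n",
    "\n",
    "y = object_to_image(z_where, obj)\n",
    "y.sum().backward()  # any scalar function of the placed image would do\n",
    "\n",
    "print(y.shape)         # torch.Size([1, 50, 50])\n",
    "print(z_where.grad)    # gradient w.r.t. (s, x, y)\n",
    "print(obj.grad.shape)  # torch.Size([1, 400])\n",
    "```\n",
    "\n",
    "Gradients reach both the pose `z_where` and the object's pixels, which is what allows the pose and the decoder to be fitted jointly by gradient-based inference."
   ]
  },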
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A discussion of the details of the STN is beyond the scope of this tutorial. For our purposes however, it suffices to keep in mind that `object_to_image` takes the small image generated by the neural network and places it within a larger image with the desired pose.\n",
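    "\n",
    "It can help to work out where an object lands for a given pose. `affine_grid` treats `theta` as a map *from* output coordinates *to* input coordinates, so `z_where = [s, x, y]` behaves like an inverse transform: larger `s` makes the object smaller, and the object's center ends up at `(-x/s, -y/s)`. The hypothetical helper `object_footprint` below computes this footprint in the normalized `[-1, 1]` coordinates used by `affine_grid`, assuming `s > 0` and ignoring pixel-level details such as interpolation and the `align_corners` convention.\n",
    "\n",
    "```python\n",
    "def object_footprint(s, x, y):\n",
    "    # Hypothetical helper. affine_grid maps an output location (u, v) in [-1, 1]\n",
    "    # to the input location (s*u + x, s*v + y). The object fills [-1, 1] in its\n",
    "    # own coordinates, so it is visible where -1 <= s*u + x <= 1, i.e. for u in\n",
    "    # [(-1 - x) / s, (1 - x) / s] (and likewise for v). Assumes s > 0.\n",
    "    center = ((0.0 - x) / s, (0.0 - y) / s)\n",
    "    half_width = 1.0 / s\n",
    "    return center, half_width\n",
    "\n",
    "\n",
    "print(object_footprint(1.0, 0.0, 0.0))  # ((0.0, 0.0), 1.0): fills the canvas\n",
    "print(object_footprint(3.0, 0.0, 0.0))  # ((0.0, 0.0), 0.333...): a third of the width\n",
    "print(object_footprint(2.0, 1.0, 0.0))  # ((-0.5, 0.0), 0.5): positive x moves it left\n",
    "```\n",
    "\n",
    "With the prior mean `s = 3` used in `prior_step_sketch` above, a typical object spans about a third of the canvas width.\n",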
    "\n",
    "Let's visualize the results of calling `prior_step_sketch` a few times to clarify this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "keep_output": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAABvCAYAAADfcqgvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAFHNJREFUeJzt3UlzG1X3BvBHrVZLljUklmVbHjMQgu1AEiigCqqANQULiiVLvgYfABZs+QYUX4AVLIAqoKAYk0pCgBASO3iOHI2WNfW76DqPJfi/4Y3545ad57eJSpYUqTWcPueee2/E932IiIhIeJywn4CIiMjDTsFYREQkZArGIiIiIVMwFhERCZmCsYiISMgUjEVEREKmYCwiIhIyBWMREZGQKRiLiIiEzD3I/ywSiWi5rwfk+35kv/fV8X5w/+R4Azrm+6HP+MHS8T5Y/+vxVmYsIiISMgVjERGRkCkYi4iIhEzBWEREJGQH2sAlInIYJBIJPPvsswCATCaDer0OAGi322g0GgCAZDLZd59oNAoA6HQ6GB4eBgBsbGwAAI4dO4ZWq8XbZrNZAMC9e/fw6aef/nsvRA4NBWMRkT85fvw43nrrLQDA7OwsA3Cj0UCpVAIQBFgLuolEApubmwCA4eFhBub19XUAQeB2XZeXPc8DAFy+fFnBWACoTC0iIhI6ZcYiIn8SjUYxPj4OICgpd7tdAEA+n0cikQAAuK6LmZkZAEClUuFtisUipqamAABnz54FACwtLcH3ff7dStypVOqAXpEMOgVjEZE/cRyHQbfZbPL67e1tVKtVAMDY2BjL12tra8jn8wCA5eVl3Lt3DwAwNDQEAIjH4xgZGeFjt9ttAMDq6uoBvBo5DFSmFhERCZkyYxGRP2k2m/jpp58AAKdOnWKjVjKZxNWrVwEAIyMj7JDO5/NsyhofH0c6nQYAXL9+HQDgeR47sl3X5eXjx48f0CuSQadgLCLyJ/F4HI888ggAoFQqsZw8PT3NwFytVlGr1Xj5zJkzAICdnR2Wp+0xqtUqlpaWAATl7bm5OQD9JXB5uKlMLSIiEjJlxiIheumllwAAL774IuehtlotVCoVAODc1enpac5vTSQSiMVivK118bqui8nJSQBArVZjo1Gr1eJjf/zxxwCAH3/88d9+aYdau91mE1YymWQ5OZvN4umnnwYQHGM79vF4nLfPZDI89seOHQMQdFAvLCwAAMrlMu7cuQMAePzxxw/oFcmgUzAWCdELL7wAAHj99dcxOjoKALhz5w4ymQwA4ObNmwCAXC7HkmahUGCwdhwHW1tbAIJpMp1OBwAwMzPD27fbbezu7gIAbty4AUDB+O84zl7RsFQq8URod3cX09PTAIC7d+8iEgl2x7OuaiB4H2zMeHt7GwBw+vRpvge+77Pz+sqVK//yK5HDQmVqERGRkCkzFgmRZcNbW1ssJafTaZZF7d+dnR2cPn0aQFAetdumUik2Bp06dYqZ9NbWFrt7NzY2mMHZ/eT+fN9nlSESifBYxuNxdkL3lqZjsRg7q8vlMteeLhQKAIIs2niex/fD5h6L6JspEqLff/8dQDBmbD/+7Xab6xWfPHkSQDBObD/cQ0NDHHMsl8t4+eWXAQD1ep3Xp9NplrpLpRJ2dnb42PL3HMfhClxDQ0MMntFolCc0pVIJExMTAMBhAyB4r2z6k91vamqqL/DeunULABjMRVSmFhERCZkyY5EQWZna8zyWNr/77ju89tprAPZ2/QGCjlwgKDXb2sa7u7v4/vvv+RjWvbu9vc0mrRMnTjArsy5sub92u80qw9jYGCsYzWYT58+fBxBUJaziEIlEeOyTySTfy5WVFQDA7du38ccffwAIMufFxUUAe1ssiigzFhERCZkyYxk4thdsOp3um2Ji42/dbpfX2zxPx3H4997re0UiEe6cA+xNR7GGnDDYc+h2u8ycpqam8PPPPwPYm2cci8Vw7tw5AMGmBDae+dFHH/Fyb3OW53m4ePEigGBcOh6P8/+Rv+c4DqsItVqNuzBVq1VuIJFOp5HL5QAE48c2Rr+zs8PG
O+sDyOfzzJZ/+uknTnOy+4soGMvAsaal9957r6+8aj+OKysr7FK1MuHY2Bgvx2Ixzr3tdDoMUolEgh2vjuPggw8+AAC88847B/Gy/k9Wei6VSvyRT6VSbACy5RR938d3330HAJifn2cJ9fz58wzo9+7d4yIS6+vrnBt7/PhxnuBYeVvuz3VdPProowCC42od651Ohydv7Xab71+32+X1u7u7uH37NgBwTvL29jbfD8/z+uaGiwAqU4uIiIROmbEMHJvTOTc3x9JzpVJhZuw4DrNgyzyAvZJvs9nkQv31ep3Zy/j4OB9jY2ODpdtBUK/XmbVWq1Vm/lZWbjabfL5ra2ucojQ5OYnl5WUAwfGyjMx1XZTLZQDB8oyWdfeW8uX+rJRsQwUAuBkE0L8cpg0VAEHGbCtw2WcvGo3y/fA8j59VNdSJUTCWgdUbVKvVKpd3zGazLDeber3OzuQbN25w/G5iYoLjdr7v940TW+k2TL1zT+01ua7LQGCvKZ/PY21tDUBwYmKBu1KpMCDMzs5yR6ErV66wBNrtdlkitX/l/trtNndqSqVSPN6xWIzXA+Cylq7r9p3c2WfO/i2Xyzwh6h3bfxiD8RtvvIF3330XwN787Lm5OQ7N3L17l2P03W6X3/upqSl2p/u+zzH4aDT6l5OfYrGI559/HgD+8lsxqFSmFhERCZkyYxk4vdmEnRWPjIywgzoWi3HlIvt7IpFg6TqVSjEDLJfL7KCOx+N8jNHR0YEo2drz9zyP5XnLhoG9s/pGo8EsKpvN8rnX63Uer83NTT5ep9Ppe2wth/lgWq0Wl7C8cOECqw+NRoOXc7kcS89DQ0PcqclxHFZ0end1Mqurq9xdyz6zD5PepUPtu3n16lUeo2QyyYx5cXGRjXHXr1/n8Uyn0/zcLy8vsyJkx3NjY6Nv5sRhoG+mDJzeL5F9KXtLUeVyue/HEQh+PO0HMJVKscxbqVQ4nreyssJO7Rs3bgzExu72gwKApfV2u80FJOwHPhaL8fUnEgluat9qtVj+TCaT+OGHHwAEwcGCSS6XYzncbiv3l06nOTXs119/ZUl0eHgY8/PzAIBLly71lVN7A6t1wduiHjdv3uTnc2pqimP+vT0PD4ve4SELtI1GA7OzswCC7nVbh31zc5O3LxQKPJnu3SVrcXGR5Wt7n1qt1kCcbD8IlalFRERCpsxYBpbneSyrTk9PMxvMZDIs2fZmzpb13b17l5lxJBJhduJ5Hn799VcAQYZoZeFB4Ps+n3/vnGM708/lctyLuLchJRaLMePKZDLMtE+ePMnXqrnFD65er3Ne99DQEIc9fvjhB2Zt169fxy+//MLbWDXD8zy+V7aQy+joKJuLZmdn+fm1xV0eJo7jsCpl+z2n02nO5XYch1WdTCbDEvTS0hKPWywW4216Z0bYbQ9jY5yCsQwcK0VFo1EGktXVVQbjyclJ/rBZuS8ej7OkOz09zSDuui7LVbu7u3y8zc3NUFfeMvaa0uk0L09NTXFdaVt44ssvv2Sp+ebNmzh16hQfw37EisUiXnnlFQDAZ599xlJpvV7n2Kb9H3J/KysrePXVV8N+GkeS7/s8EbYVyGq1GsvR9+7d4xSyn3/+GY899hiA4ITnq6++AgA89dRTfduLWmC234PeMvZhoTK1iIhIyJQZy8Cx5qydnZ2+cq1luyMjI8z0LOvtdDpscLLSFxCUq2x+bj6f52PMz8/j2rVrB/Bq7s/mAruuy/mrnU6HjT22I1OxWGQZ1HVddpPPzMzgySefBBBkEZZRR6NRXLlyBQDwxBNP8DgN0kIn8nCqVCrcBcuGWM6dO8eS/gsvvMBhpkKhwMpPNpvFhQsXeD8rdTuOw6EDW/RmcnLy0DVwKRjLwLHxnuPHj3MMtFQq9Y0PW5nLbttqtViCvnv3Li9vbGz0TTmxYFwsFvuCdlisrLa1tcUAXKlUWJ62DutkMtnXbW0nLMVikaXnaDSKsbExAEG3uJW1b9++3ddlKhKm
VCrFk0L7Dq6urjIwX7t2jUNI8Xgct27dAgAsLCxwTLjT6fBzXygU+Di2ct3W1tahm9qkMrWIiEjIlBnLwLES1cbGBkuzm5ubnNO5vr7OM2DLFre3t5n1Tk1Nsfu12+2yXLW8vIy5uTkAQUdr7+IaYbGz+94qQDqdZqOKvb5yucxqwLlz5/q2XrTMv16v45tvvgEQHEObt1mr1TjXepA6yOXh1LsIjWXDkUiEn+N4PM6GrEQiwQVSfN9nJalQKHCoxsrYwN6iH73z9w8LBWMZOFbCisVi7JoeGxtjYPI8j12YNvaUSqVYqmq1WuzIHh4e5vrBw8PDfLzV1VUG/TDZ66hWqyyt12o1/ujY2Hgul+MPV7FY5BZ8o6OjDMzRaJQnKY1Gg2Vqz/P4WtVNLWHrdDr8nNq4bzKZ5MId0WiUvRSLi4u4dOkSgGCoyoZnlpeXsb6+DiAIzFaStu99b8A/LFSmFhERCZkyYxk4VmIaGhpiWbXdbrP0nMlk2NRhmW48HudtK5UKG5Zc12XTUqlU6itfDcKZs1UBHnnkkb61pG0xiZmZGQBB5v/bb78BCF5/73KZ1nEejUZZqt/Y2GAX+YkTJ5QRy8Co1+us7Nj3tFqt8nIqlWLWWy6X+T3N5/OcizwxMcGtLVdWVnj92bNnARzORkVlxiIiIiFTZiwDx8ZLi8UipzJ4nsex306nwwzXVuHp3aWo0+lwfHVnZ4e3XVlZYdNWvV4fiBW4PvzwQwDBsom98ybtrN8aryqVChtWstksj1Hv5Xa7zepA77hzKpVi5nD58uWDeFki/1WpVMLNmzcBAM888wyA4HtqvQ/NZpPf9WvXrnG6nu/7rCStra1hYmICQFA9suYvy6jte3OYKBjLwOltvrIGp96g2u12/7Lrzfr6Orutd3d3+7qNe7cetEYPz/PYEBamb7/9tu9fkaPO8zwu0mHDMdYxDQTB2E6U0+k0mw8TiQSHpWKxGBYWFgAA33zzDZ544gkAe1uE3rp1S/OMRURE5MEoM5aBY+XaarXKsnIkEuHcQ8dx+kpXQDD16c6dOwD6S9bVapUlr2w2y+lPMzMz+OKLLw7oFYmI8TyP30PbX7xWq+HEiRMAguEk+17X63VO+btz5w5L06lUiruxtdttLiXb26x52CgYy8CxOYanT5/mPNtoNMqO4Hw+zy+dlbQ7nQ7HRYeGhvgFzmazfQsBPPvsswCCnY8GYcxY5GFk33GbCdBqtfj9Xltb4/XdbpcBOBqNckncRqPBoSrHcdhxbWPGmUxmIGZLPAiVqUVEREKmzFgGjnUNf/LJJ5xnu7u7y7PicrnMzdstQ/5ztmzdlZOTk+xMTiQS+Prrr/l4NtdRRA5Oo9HgLAnbZSmZTHIP4263yzn1Fy9e5JoCa2treO655wAAV69eZSm6UCiwgmZNmb7vH7oGLgVjGTg29vvmm2+G/ExE5P+b67p9HdJAMHPio48+AhAshGNjyZ9//jmnKZ44cYIL2czPzzMY907psyVgi8WiytQiIiLyYI58Znz+/HkAQWnEzpSi0SjPzBzH4TzU3r/3ljx7GwXssuu6XLR8c3OTnb4iIvLf/fHHH1zsxtYJ+P3339l06fs+s2Tf9zlz4tKlS2zQSiQSvP3S0hIX+bAmzng8zt/nwyJykHX1SCRy4EV829VnZWWF011mZmb4pna7XaTTaQB7HX6e5/FNrdVqHKuMRCJceCKVSjEwv/3223j//ff/lefv+/6+ay1hHO/D7p8cb0DHfD/0GT9YOt4H63893ipTi4iIhOzIl6ktA87lcuyqjcfj7NK9ceMGB/2tTO26LrtxXddlGXtra4tz3kZHR9lMoA3bRUTknzjywdg68azUDAQBenZ2FkAQgC3w2qowrusyQMdiMQbbiYkJLhRRq9WwtLTUdz8REZH9UJlaREQkZEc+M7ZJ4L1rnLquiytXrgAArwP6N3K3yeiRSISTx13X5ZZ99Xqdax7bTiEiIiL7ocxYREQk
ZEc+pbMmq1wux+lK0WgUZ8+eBYC+uWi28pPnebzfc889xz00a7Uab/P444/zfodtpRcRERksRz4YW3PV1NQUF/Ko1+sMoO12m+uk2tJs2WyW1zWbzb7FQKxMbTuMAHtLsImIiOyHytQiIiIhO/KZce8cYJvmdOzYMW5MPTw8zKlNtlvQ6uoqFhcXAQQNXpYFr6+vY2FhAUCwab2VuG3HEBERkf048sHY1jWtVCooFAoAgo3lbSGPRqPRN6cYCLbXswVCHMfhMpkjIyPczD6ZTHJ9a/u7iIjIfqhMLSIiErIjnxlbCXpiYoIb0pfLZUxOTgIIdlyyzNjmEzebTTQaDQDBjiA2j3h8fJwZc7lcZkasecYiIvJPHPkoYgE4nU6zK3p0dJSd1WfPnuXCIBaMk8kkx4HT6TRvm0gkEI/HAQA7OzvsvhYREfknVKYWEREJ2ZHPjG3z6t3dXXY/53I57ubUaDS4qId1Xg8PD7ODOplMsgu70+kwM/Z9n81ctnmEiIjIfhz5YGwd1AA43uu6LtLpNICgy3p8fBxAUHoGgoVA2u0272dBulgsotls8jobj+5dxUtERORBqUwtIiISsiOfGfeWkC1LTiQSbOa6fPkyy9BWmk4kEmzgOnPmDH777TcAwXxiK01XKhXOM7YlMkVERPbjoQnGhUKBpedr165x7HdoaIgbSFhX9ejoKLa2tnh/m7pUqVQ4DSqbzbI83btOtYiIyINSmVpERCRkRz4ztky2dwvFnZ0dJJNJAMF61NlsFsDeAiHr6+uYmZkB0N9B7TgOy9SO4zCTtutERET248gHY9v8IZPJcJoTAHZQe57HhUGsm9p1XXZQNxoN5PN5AMF0JhsnPnPmDEvZtliIiIjIfiilExERCVlEWZ2IiEi4lBmLiIiETMFYREQkZArGIiIiIVMwFhERCZmCsYiISMgUjEVEREKmYCwiIhIyBWMREZGQKRiLiIiETMFYREQkZArGIiIiIVMwFhERCZmCsYiISMgUjEVEREKmYCwiIhIyBWMREZGQKRiLiIiETMFYREQkZArGIiIiIVMwFhERCZmCsYiISMgUjEVEREKmYCwiIhKy/wDmgSdm2uGlvAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x1165c9c18>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pyro.set_rng_seed(0)\n",
    "samples = [prior_step_sketch(0)[0] for _ in range(5)]\n",
    "show_images(samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generating an entire image\n",
    "\n",
    "Having completed the implementation of a single step, we next consider how we can use this to generate an entire image. Recall that we would like to maintain uncertainty over the number of steps used to generate each data point. One choice we could make for the prior over the number of steps is the geometric distribution, which can be expressed as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "keep_output": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sampled 2\n",
      "sampled 3\n",
      "sampled 0\n",
      "sampled 1\n",
      "sampled 0\n"
     ]
    }
   ],
   "source": [
    "pyro.set_rng_seed(0)\n",
    "def geom(num_trials=0):\n",
    "    p = torch.tensor([0.5])\n",
    "    x = pyro.sample('x{}'.format(num_trials), dist.Bernoulli(p))\n",
    "    if x[0] == 1:\n",
    "        return num_trials\n",
    "    else:\n",
    "        return geom(num_trials + 1)\n",
    "\n",
    "# Generate some samples.\n",
    "for _ in range(5):\n",
    "    print('sampled {}'.format(geom()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a direct translation of the definition of the geometric distribution as the number of failures before a success in a series of Bernoulli trials. Here we express this as a recursive function that passes around a counter representing the number of trials made, `num_trials`. This function samples from the Bernoulli and returns `num_trials` if `x == 1` (which represents success), otherwise it makes a recursive call, incrementing the counter.\n",
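    "\n",
    "As a quick check of this description, the sketch below re-implements the same recursion with Python's `random` module in place of `pyro.sample` (`geom_sketch` is a hypothetical stand-in for `geom`) and compares empirical frequencies with the geometric pmf $P(k) = (1 - p)^k p$, whose mean is $(1 - p) / p$, i.e. 1 for $p = 0.5$.\n",
    "\n",
    "```python\n",
    "import random\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def geom_sketch(rng, p=0.5, num_trials=0):\n",
    "    # Same recursion as geom above, with rng.random() < p standing in for a\n",
    "    # Bernoulli(p) draw: return the number of failures before the first success.\n",
    "    if rng.random() < p:\n",
    "        return num_trials\n",
    "    return geom_sketch(rng, p, num_trials + 1)\n",
    "\n",
    "\n",
    "p = 0.5\n",
    "n = 100000\n",
    "rng = random.Random(0)\n",
    "counts = Counter(geom_sketch(rng, p) for _ in range(n))\n",
    "for k in range(4):\n",
    "    # Empirical frequency next to the pmf P(k) = (1 - p)^k * p.\n",
    "    print(k, counts[k] / n, (1 - p) ** k * p)\n",
    "mean = sum(k * c for k, c in counts.items()) / n\n",
    "print(mean)  # close to (1 - p) / p = 1.0\n",
    "```\n",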
    "\n",
    "The use of a geometric prior is appealing because it does not bound the number of steps the model can use a priori. It's also convenient, because by extending `geom` to generate an object before each recursive call, we turn this from a geometric distribution over counts to a distribution over images with a geometrically distributed number of steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def geom_prior(x, step=0):\n",
    "    p = torch.tensor([0.5])\n",
    "    i = pyro.sample('i{}'.format(step), dist.Bernoulli(p))\n",
    "    if i[0] == 1:\n",
    "        return x\n",
    "    else: \n",
    "        x = x + prior_step_sketch(step)\n",
    "        return geom_prior(x, step + 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's visualize some samples from this distribution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "keep_output": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAABvCAYAAADfcqgvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAF0ZJREFUeJzt3UtvG+f1BvBnhsM7RUqU5EhmIl8UKZFrtzCatsiiCJou+pH6Fboq0FV3/QDdFF0UAYoCBVqg7aIB4sJJHCcq5YtkS5F1I8XrDNnF4Dx6Jxb+/ziRNaH8/DZRaF7EGYpnznnP+77eeDyGiIiIpMdP+xcQERF51SkYi4iIpEzBWEREJGUKxiIiIilTMBYREUmZgrGIiEjKFIxFRERSpmAsIiKSMgVjERGRlAXn+WKe52m5rxc0Ho+9b/pYHe8X922ON6Bj/k3oM36+dLzP19c93sqMRUREUqZgLCIikjIFYxERkZQpGIuIiKRMwVhERCRlCsYiIiIpUzAWERFJmYKxiIhIyhSMRUREUqZgLCIikrJzXQ5TJofnxSu4BUEA34+v2cbjceK/dj/3/097jvF4zJ/dx49GIwBAFEVn/NuLiEwWBWM51fvvvw8A+NWvfoVGowEA6Pf7AIC7d+8ik8kAABqNBobDIQAgn88zwO7u7mJ1dRUAsLm5iXq9DgDwfR/Hx8cAgN/97ncAgN/85jfn8ZZERL6zVKYWERFJmTJjOdXMzAwAoFAoYGNjAwCwuLgIIM6G9/b2AACVSgVBcPIxarfbAIByuYzd3V0AQK/Xw9bWFm+/fPkyACCbzb78NyIiMgEUjOVUNo7rjvf+97//BQBMT0/ztkwmg1wuByAuTR8cHAAAVldX8fDhQwBAGIa8z9OnT9HtdgEAR0dH5/RuLp58Pg8AePfddzkG7/s+hwn6/T56vR4AYHZ2lsfcxv+HwyGfwzUYDHhxFQQBBoMBgHio4cGDBy/xHYm82lSmFhERSZkyYzmVZUfT09OoVqsAgFarxdssc/7444/x1ltvAQCmpqaYgT1+/JiNWvV6nRnZ7du38fTpUwDAG2+8cU7v5uKxykS1WuXP7XYbpVKJt3c6HQDxUEKlUgFwkhmHYYgwDHmbDRl0Oh025xUKBQ47nJZFi8jZUTCWU1kwfvDgAUufxv3C/vnPf46PPvoIQBxcC4UCgHiceGFhAUBc+iyXywCADz74AGtrawDALmz55qIo4rnK5XIcDhiPxwzMtVqNQwIWoH3f57kKw5Cd8oPBALOzswDigG9B3J5XRF4OlalFRERSpsxYTvXll18CAB49esRyspUqy+Uys+V//OMfnIf87NkzZltRFLH5p9PpMHtbW1tj2Vulz2/P87xEtmtNW/l8nmXoVqvFUraVo4vFIkvQQRDwvGWzWWbJYRg+VxURkZdDmbGIiEjKlBnLqWq1GoC4iceafsz8/DyazSaAuJlrZ2cHQNwoZNNslpeXmRnfuXOHmfHh4SFu3rwJABx3lhfnLldqzXTZbJbHP4oinrcoip5bstQa7b76czabTWTJX50SJSIvh4KxnMoart5++21+2U9PTwMAms0mm3w+/vhjXLp0CQCwtbWFlZUVAPGc5Pn5eQBxY5cF47t37+LPf/4zAODJkyfn9G4uHnd9b7uoGY1GPM7u/PBMJpMoXwPJOcRBEPA52u02S9nj8Zg/KxiLvFz6CxMREUmZMmM5lWVSnucxI7b5wYuLi1yN6fbt2ygWiwDiRi2brlQoFPD48WMA8XQZawir1+t47bXXAIBTb+TFWWY8GAx4/MfjMTNct2mu1+txapI7NGA/R1HEJrBSqcSMGQCbwETk5VIwllPZ2tP379/Hu+++C+CkdL2xsYGpqSkAcaeuBd033niDZdBOp8OfW60Wd21qtVoscVsQkRfnlqDdLSjtmLrjxKPRiAuw2Hiw53m8bzabZTn6+PiYgTubzfLiyg3QInL2VKYWERFJmTJjOZV1U7/55pv47LPPAJyswhSGITeB6Pf7LGMPh0NmYG4GvLy8zLmrr732Gp/HzeguCiv9
1mq1RGOVzdcNwzDRIOVu8mDZrv0bcJIB+77P4+V5XuL57Ge3+7nb7SY2fLD72L8XCgVmvWEY8nV6vR5X3er1erzd/d1E5OwpGMuprCx5eHiIR48eAThZS3plZYXjx+vr6wzcT5484a5NpVKJ48Tz8/MM0jdv3mQX9UXcBcgWNHnvvfcYdHu9HgPp4eEhL1KA5AWJlfWHwyG7l91Aaxc02WwW6+vrAOLpZBZUe70ex35rtVoikNsFkwXoMAx53zAM+drFYpEB2+3UftWC8V/+8hf2NozHY67L/sknn/Cc3Lp1K1HGt4vMer2Of/3rX/wZABYWFvg3c/36dV5svf/++/ybkVebytQiIiIpU2Ysp7LsrVwuc1cm4y51+bOf/YxX9oPBAFevXgUQl7Stsevw8JCZ1wcffMCfl5eXX/r7OG+WSWazWXakZ7NZZrruIiqFQoEZVxAEie5ny5wse3V3Z3JL09vb2ywrAyfnLQxDdqvv7e3x+ey/buna8zz+rsPhkPeZmZlhw9er1vm+srKCu3fvAogrC3NzcwCAd955h5tuZLNZrK6uAojn2Nu5fPz4Mc+DnZvRaMT77uzssNqkbnUxCsZyqsPDQwDxl7MFgddffx1AvAa1leT+85//4MaNGwDiwGFd2Kurq/j8888BxOVXWwDkhz/8IcebL+K6x/aehsNhYmzcbi+Xy4lpSfZlPB6PeZ9cLvfcIhtBEPDL3vM8Ps5K4UB8nO1xmUyGr1MulxlAbIEWz/N4UTQajfg4z/MS5XX7ndxdnuzzMBqN+LO7yIh7H3d1r0ny8OFDXlju7u7yvedyOR7Ld955B//85z8BxOfbhhzcLSvNYDDghVW9Xk90w4sAKlOLiIikTpmxnMpKpqPRiPNRLVsOgiCxiIRlw/v7+8wOisUi16AuFApsOLJ/Ay7m3FU3o7UMKp/PMwvt9/vMWPP5fGKurzUDZTIZHhvbWanX6zFjzWQyieN/2l7Do9GI2XMURSwz25zx+fl5Zr3tdpuZ2nA4ZBOau9/0+vo6G5rsPboLi4xGI/4enufx+f7whz987WP3XRKGIc9BLpfD7u4ugPhvwBoZP/30UzbVBUHA41IsFvk3Yee6UqnwHPi+zyGcizijQL4ZBWM5lX2ZNhoNBuGPPvoIQNwZ+uzZMwDxF9XGxgbvax2jd+7cYWm61WrxS7vdbmN/fx8A+EV2kVgA63a7LEH6vs9u8vF4zKDV7XYTpWm3E93thAbiL3h3swf7Es9kMnzN0WiEmZkZAHFwt7K26/LlywDic2iPOzw8ZBkWQGK7Rbug6Ha7HFe2IOV2W7sl8u3t7Ylfd3x6eprHr9vtYnFxEUB8fq3cfOPGDX726/U6j8twOOR5cI+VfQZyuRz/NrTmtxh9EkRERFKmzFhO9fvf/z7xX3kx7jrRg8GAJcooiph5ug1cwMkcZbcxyl04xG2ssp/DMOQypZ7nMWMeDofM7Obm5nj/H//4xwDirmsrpfq+z6y72+2yHD0ajZjhvvnmm4llMgHg4OCA82jtvdnvZ3PMJ1WxWOS56ff7zIYXFxdZUXjy5Akb4vr9fqKZzpodrby9t7fHZrbPP/8ct27dAnAyHCSiT4LIGXK7kt0vc7cT+bTS5GAwSDzWgqMF42KxyC/2Wq3G5y4UConpVO44r3X0jsfj516z2+0mgoc75cluj6KIgX5vb4/vwZ2aZcMS9Xqd498XofT64YcfYmlpCUA8bGDHtdlscl32Tz/9FN/73vcAxFuGvv322wCApaUl/P3vfwdwspJdo9HA9vY2gPjiyKY2acxYzOT/1YiIiEw4ZcYiZ+i0NahnZmYS5WjrrHbv7y62kcvlEtsbAnHnsi1pCZzsvuSWr925rkByARIrs1rTXLFY5O9RrVbZtd1qtXj7YDDg63S7XWa87rrZ1qi0t7fH38/3/YmfP3vjxg2W5dfX1/k+r1y5wiGEWq3GbLdYLLKZCwB+8pOf8HYA+NOf/oS1tTX+ux1v
tylPXm3KjEVERFKmzFgAxFmUNRm5Y5buylDufd2Vl9xMyb2v24Tk3n6RswLLHt09goMgYGZcLpc5tSubzbKB5/j4mBlpFEWJ8WYgzoRtapG741KhUGA2PB6P+XOhUOBjh8MhM1V7vUwmw9dwd2oKgoDT1vL5PJvKPM9jQ5iNmeZyOZ7LhYUF/v7lcpn3mWQ2FWlpaYnH8uDggI1vlUqFGfMnn3zCFbu2trZ4DO0zfuvWLd736OiIVYRJryDI2VEwFgDxbkq//e1vAcRfMvZF3W63ubuSfWEvLy+z/NbpdBgYoijivFS3+WdtbY0NK+PxGL/4xS8AYOI7bk9jwTAMQx6vTqeTKPfasev3+wyO1WqVjx0MBrzdAlyr1eKXubtV5dddVMNKru+99x6AuMTqBnS7ELNytr0Xuz2Xy3G9ZXc+rf37cDhkt/Bpi5BMmt3dXZ6D/f19BtWpqSmev3q9zgukq1evJi547CLFjnGlUklcKF2EYyRnS2VqERGRlCkzFgBxFmRX/4eHh4nyo82ltCwom80mllq0f3/69Cl/fvz4MVeU2tnZOXWe7UXkluYtK/I8L7EXsWVQ/X6f5ftqtcryZ6FQYFZmZe/p6Wlmnu7UocuXL+P73/8+gDgLs2z88PAwsRey3f7Tn/4UwMk82K8+DjjJ7qMo4ufgb3/7G7NAy9DL5TKzZLdpazQaTfwevVtbW9ypyd2kY3Z2NrFcqFvxsGNRKBRYprfzvrS0xGOZz+cv5CYp8u0oGAuA5DhwPp/nF3KpVOL4lgWLlZUVrq07Ho+5XGY+n2cwPjo64tjoYDDgc1/kQAycBMqvblHoBlf7UnYXfBgOh4l5xvazWxI+bZcl97z5vs8y88zMDINnqVTi/S2QtFot/lwul/k7hWGY6MK2uc1RFDFg2/BCNptN7OpkZexutzvxi1lcv36dpeSjo6PEBactjdloNLgoyvz8PBYWFgDEZW07LnaB22w2+XzD4ZBDOCJGZWoREZGUTfblq5yZfr/PTLZarfIqfjwe8+rfukW73S7Lzl/tiLbHuXNioyhKPN9F7KI2bhe6O1fYMqHj4+PEjkvm+PiY2WQURaxMuMteuptNWOaVyWQSOydZ2dSd15zNZpk9W7a8uLiYyMQt693d3WU5vFarYXl5GQDw73//m1m6/U7ZbJbNTDs7O3yc53kT3yW8ubnJMvXOzg53u3r06BH/HmZmZngsBoMB/vrXvwKIm7zs82/Vh/v373Pp0Ewmw0ZHlavFKBgLgHic69q1awDiL3v7stjc3MTKygqAk5Jpp9PB/fv3AcRjlvbFUywW+eXUbrcTY48WMNyS6UVk783tli2VSjx2BwcHDJhuwK5Wqwygg8GA97exynq9zpJxtVpNjM/a43K5HF/f3dJwOBwyaLrlU3uOMAwZSKemphILirjjw1YCtyUe3e0zc7kcy9fuGtmTqtFo8DNbr9exubkJIH5vdkHz7NkzfPHFFwCA1dVVLvQxGAz492HlfyvhA/F4vjuNUARQmVpERCR1yowFQJwpWTbmbmxQLpefa0YJggA3btwAkMziOp0Os4bhcMjbB4MBMy/bmP2ismPlbsTgzt21jAiIsyI7LsPhkN3XmUyGx9Gy22q1yjJ1EASJPYwtG+50OsxkW61WYiGPnZ0dAEgsPmJZbalUYiZbqVQSWbc9bjQaMSO21+71evyc9Pt97tHb7/cnfh5tr9djxSebzXKYYXt7m7taZbNZvPXWWwDi42lz6WdnZ9nwtbW1BSB5/C5dusTPhjJjMQrGAiAes7xz5w6AuPRsnaEHBwecpmEBIooijoVtbGzgypUrAOJAa18+1Wo1MeZmxuPxK/EF5E5h8n2fQev4+DjxRWyl0EKhwEAeBEEiEADxeXB3YbJg1+12E9Op7P6VSoXP3e12n+veDcOQF0uZTIa/n5VV7TnsAsz3fV44uAuV2Gfi4OCAFwVHR0cTPxYahmFiKpd7IWrnxvM8TuH68ssvE+PE
Nq5sf0e9Xo/Hr1qt8sLlIg/ZyIvRJ0FERCRlyowFQJyZ2UIQ7prDV65c4c40VmLe2NhgMxFwMu/02rVrzKL39vbYnb22tsby6EXImv4vlv189tlniS5ne8+5XI73iaKImZHv+4n7u3sNA3GmZvOMwzDk8Xf3M3bXqXabqNw5zHZbu91mdm3NXUBc6rZServdZkn6+PiY3cC2kMXu7i6z60ajwd9pb29v4s9xs9nkUIy75Ovc3BybuVqtFjPcYrHI83f37t1ERQGIKwfXr1/nbTZEMOnHSc6OgrEASC7Y4Y4fN5tNlqGbzSaAeIzRgkWxWGTn9cbGBoNxsVhkkNje3mbJr1arXegytQUq66Z92RqNRmIalB3br37J28XQ5cuXAcQXBRZgyuUyL7jcADMajXg+G40Gn8sCTK/X4+fk3r17vO/MzAzHTyfVL3/5y8RwggVXd4ON0WiUmMJlx95dZc49H+5CKPZ8k75SmZwdlalFRERSpsxYAMRX85a9ViqVxAITdrtd5VcqlcR2epZB9Pt9ZludTofZWavVYgNRt9tVae4MudszBkGQ6JC2LD2Xy7HEbUMKs7OzvG1/f5/Z7qVLl9j9HQQBlzptNps8h3aO2+02G/lqtRqz5HK5PPHd1Pb5FjkvyoxFRERSpsxYAMSZ7NOnTwHE44ruXrc2JmzTNNxxsyiKmD15nsest9/vc3zS8zyuQPSqTG06L+70KM/z2Izl+z6z1m63yyzYqhj1ep23HR0d8XxmMhlWRWZnZ/l8xWKRc5jNzMwMG/36/T6z6263m9gFSkT+fwrGAiAuZVqTTqVSYUNPPp/n/EnrIr1+/ToDwOHhIf99bm4u0eVrjUWzs7P8gp/08uV3jbvoSjabZWPQaDTiBRXwfNn14OCAgdntcC+VSgzMm5ubiW017T42bNHr9XjfqampxJKf9vkRka9HZWoREZGUKTMWAHEGbNM07t27x6X7bt++zTL0actehmHIDOz+/fvMkvP5PB4+fAggzrasxLm9va0GrjOUy+Wea6wy7t7KluH++te/BhCfH8tq3f2Tu90uPwe9Xo+l7D/+8Y/PzZ11X9N9jX6/z7nUIvL1KBgLgOTOSq+//npivWMrPduX9GAwYKdupVLhuOHS0hKD+GAwYNm72WxyqU139xr59txdm4bDIQOiu8uSGzxtIRbgZO6x7/uJ8+1ucWljv+7jROTsqUwtIiKSMmXGAiDOkqzpplwuc2eara0tNm5Zw8+1a9dw8+ZNPtbmrpZKJdy7dw9AXOK05f/m5+eZpS0sLEz8xvPfJUEQsJTsZsYAEvOPrbPahhqiKEosnWnnxx2usKEIEXn5FIwFQPyFbSVmd5P6UqmEq1evAgAXlPB9H8+ePQOQHGM8Pj5mMPjBD36ABw8eAIiDsQVmdzEQ+faGwyE7pXO5HKciudOL3GUr3ZK27dTU6/W4lKXneYnubAVkkfOhMrWIiEjKlBkLgDgztuaqQqHAkvTGxgZ+9KMfAQC++OILAHFmbPu1ep7HRTwKhQKzLcu0gLjxy0qp1vglZ8P3fe6yBJzMJ/Z9P7G5gbuDExCXo91mL6t6uNl1EASJ5xaRl8dzx5he+ot53vm92AUxHo+/8XJVL3K8a7Uabt26BSD+IrfFOfb39zE3NwcAXLjD3Xjd7g8ktwqsVquJMUu7z2AwwIcffgjg+ak43wXf5ngD5/8Zr1arXBnNLf/7vp8IwHZx5U5PsvuHYcj7BkHArnp3tbT19fWX9h7O6zMuMR3v8/V1j7fK1CIiIilTZvwdp6vY8zVpmfFFoM/4+dLxPl/KjEVERCaEgrGIiEjKFIxFRERSpmAsIiKSMgVjERGRlJ1rN7WIiIg8T5mxiIhIyhSMRUREUqZgLCIikjIFYxERkZQpGIuIiKRMwVhERCRlCsYiIiIpUzAWERFJmYKxiIhIyhSMRUREUqZgLCIikjIFYxERkZQpGIuIiKRMwVhERCRlCsYiIiIpUzAWERFJ
mYKxiIhIyhSMRUREUqZgLCIikjIFYxERkZQpGIuIiKRMwVhERCRlCsYiIiIp+x9N6ru6GFvI9AAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x1164aeb38>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pyro.set_rng_seed(4)\n",
    "x_empty = torch.zeros(1, 50, 50)\n",
    "samples = [geom_prior(x_empty)[0] for _ in range(5)]\n",
    "show_images(samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Aside: Vectorized mini-batches\n",
    "\n",
    "In our final implementation we would like to generate a mini batch of samples in parallel for efficiency. While Pyro supports vectorized mini batches with `plate`, it currently requires that each `sample` statement within `plate` make a choice for all samples in the mini batch. In other words, every sample in the mini batch must encounter the same set of `sample` statements. This is problematic for us because, as we've just seen, samples can make differing numbers of choices under our model.\n",
    "\n",
    "One way around this is to have all samples take the same number of steps, but to nullify (so far as is possible) the effect of the superfluous random choices made after the sample is conceptually \"complete\". We'll say that a sample is \"complete\" once a zero is sampled from the Bernoulli random choice, and prior to that we'll say that a sample is \"active\".\n",
    "\n",
    "The first part of this is straightforward. Following [1] we choose to take a fixed number of steps for each sample. (By doing so we no longer specify a geometric distribution over the number of steps, since the number of steps is now bounded. It would be interesting to explore the alternative of having the batch keep taking steps until every sample in it is complete, as this would retain the geometric prior.)\n",
    "\n",
    "To address the second part we will take the following steps:\n",
    "\n",
    "1. Only add objects to the output while a sample is active.\n",
    "2. Set the log probability of random choices made by complete samples to zero. (Since the [SVI loss](svi_part_iii.ipynb) is a weighted sum of log probabilities, setting a choice's log probability to zero effectively removes its contribution to the loss.) This is achieved using the `mask()` method of distributions.\n",
    "\n",
    "(Looking ahead, we'll need to take similar measures when we implement the guide and add baselines later in this tutorial.)\n",
    "\n",
    "Of course, one thing we can't undo is the work done in performing unnecessary sampling. Even though this approach performs redundant computation, the gains from using mini batches are so large that it is still a win overall.\n",
    "\n",
    "Here's an updated model step function that implements these ideas. In summary, the changes from `prior_step_sketch` are:\n",
    "\n",
    "1. We've added a new parameter `n` that specifies the size of the mini batch.\n",
    "2. We now conditionally add the object to the output image based on a value sampled from a Bernoulli distribution.\n",
    "3. We use `mask()` to zero out the log probability of random choices made by complete samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prior_step(n, t, prev_x, prev_z_pres):\n",
    "\n",
    "    # Sample variable indicating whether to add this object to the output.\n",
    "\n",
    "    # We multiply the success probability of 0.5 by the value sampled for this\n",
    "    # choice in the previous step. By doing so we add objects to the output until\n",
    "    # the first 0 is sampled, after which we add no further objects.\n",
    "    z_pres = pyro.sample('z_pres_{}'.format(t), \n",
    "                         dist.Bernoulli(0.5 * prev_z_pres)\n",
    "                             .to_event(1))\n",
    "    \n",
    "    z_where = pyro.sample('z_where_{}'.format(t),\n",
    "                          dist.Normal(z_where_prior_loc.expand(n, -1),\n",
    "                                      z_where_prior_scale.expand(n, -1))\n",
    "                              .mask(z_pres)\n",
    "                              .to_event(1))\n",
    "\n",
    "    z_what = pyro.sample('z_what_{}'.format(t),\n",
    "                         dist.Normal(z_what_prior_loc.expand(n, -1),\n",
    "                                     z_what_prior_scale.expand(n, -1))\n",
    "                             .mask(z_pres)\n",
    "                             .to_event(1))\n",
    "\n",
    "    y_att = decode(z_what)\n",
    "    y = object_to_image(z_where, y_att)\n",
    "\n",
    "    # Combine the image generated at this step with the image so far.\n",
    "    x = prev_x + y * z_pres.view(-1, 1, 1)\n",
    "\n",
    "    return x, z_pres"
   ]
  },
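  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see concretely why zeroing a log probability removes a choice from the loss, here is a small illustration in plain PyTorch. It is a sketch of the effect we rely on `mask(z_pres)` to have in `prior_step`, not Pyro's implementation: it uses `torch.distributions` instead of Pyro, and it multiplies each sample's log probability by a 0/1 mask shaped like `z_pres`. The `mask_demo_*` names are hypothetical. After running it, the third entry of `mask_demo_masked` is zero however unlikely that sample's value is, so it adds nothing to a sum of log probabilities, while the active samples keep their original log probabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.distributions as tdist\n",
    "\n",
    "# Three samples in a mini batch; the third is already complete (z_pres = 0).\n",
    "mask_demo_z_pres = torch.tensor([[1.], [1.], [0.]])\n",
    "\n",
    "# A batch of 3-dimensional choices (e.g. z_where values). The complete\n",
    "# sample holds an arbitrary, very unlikely value.\n",
    "mask_demo_value = torch.tensor([[0.5, 0.1, -0.2],\n",
    "                                [1.0, 0.0, 0.3],\n",
    "                                [9.0, 9.0, 9.0]])\n",
    "mask_demo_prior = tdist.Normal(torch.zeros(3, 3), torch.ones(3, 3))\n",
    "\n",
    "# Per-sample log probability, summed over the event dimension\n",
    "# (the role .to_event(1) plays in the model).\n",
    "mask_demo_log_prob = mask_demo_prior.log_prob(mask_demo_value).sum(-1)\n",
    "\n",
    "# Multiplying by the mask zeroes out the contribution of complete samples.\n",
    "mask_demo_masked = mask_demo_log_prob * mask_demo_z_pres.squeeze(-1)"
   ]
  },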
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By iterating the `prior_step` function we can produce an entire image composed of multiple objects. Since each image in the multi-mnist dataset contains zero, one or two digits, we will allow the model to use up to (and including) three steps. In this way we ensure that inference has to learn to leave at least one step unused in order to correctly count the number of objects in the input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prior(n):\n",
    "    x = torch.zeros(n, 50, 50)\n",
    "    z_pres = torch.ones(n, 1)\n",
    "    for t in range(3):\n",
    "        x, z_pres = prior_step(n, t, x, z_pres)\n",
    "    return x"
   ]
  },
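  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the success probability at each step is 0.5 times the previous value of `z_pres`, the number of objects drawn from this prior is geometric, except that all of the mass beyond two objects is collapsed onto three. The plain-Python sketch below (the helper `num_objects_prior` is hypothetical, not part of Pyro) enumerates every sequence of `z_pres` values and accumulates the probability of each object count. It gives P(0) = 1/2, P(1) = 1/4, P(2) = 1/8 and P(3) = 1/8, so the prior reserves 1/8 of its mass for a third object that never appears in the data, which is why inference has to learn to leave the final step unused."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "def num_objects_prior(num_steps=3, p=0.5):\n",
    "    # Enumerate every sequence of z_pres values and accumulate the\n",
    "    # probability of each object count under the prior.\n",
    "    probs = [0.0] * (num_steps + 1)\n",
    "    for seq in itertools.product([0, 1], repeat=num_steps):\n",
    "        prob, prev = 1.0, 1\n",
    "        for z in seq:\n",
    "            success = p * prev  # a 0 at the previous step forces a 0 here\n",
    "            prob *= success if z == 1 else 1.0 - success\n",
    "            prev = z\n",
    "        probs[seq.count(1)] += prob\n",
    "    return probs\n",
    "\n",
    "count_demo_probs = num_objects_prior()"
   ]
  },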
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have now fully specified the prior for our model. Let's visualize some samples to get a feel for this distribution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "keep_output": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAABvCAYAAADfcqgvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAGI1JREFUeJzt3VtvG9fVBuB3eBZFiiIlSpZsST6nTl2khhugLdoCRX9BUfQm6L/or+llbwoUvSrQqwK9CJCicIPYhuG4bWIrsa0zRYnnM4ffxWC92hMHHxJH1kjW+9xIoChSIoezZq299t7eZDKBiIiIRCcW9R8gIiJy3ikYi4iIREzBWEREJGIKxiIiIhFTMBYREYmYgrGIiEjEFIxFREQipmAsIiISMQVjERGRiCVO8sk8z9NyX9/SZDLxXvd39Xp/e9/l9Qb0mr8OHeMnS6/3yfqmr7cyYxERkYgpGIuIiERMwVhERCRiCsYiIiIRUzAWERGJmIKxiIhIxBSMRUREIqZgLCIiErETXfTjuP3yl78EAMTjcdRqNQBAPp/HcDgEACSTSUwmwRz1wWCAQqHA28fjMQCg3W7D84I52YlE8HIcHh4in88DAFKpFLa3twEAjx49Ool/S0REzhllxiIiIhHzLHM8kSc75qXUPvzwQwDAwsICNjY2AACrq6uoVqsAgGKxiG63y/u/++67AICXL18iHo8DAOr1OrNny4ZfvHiBy5cvAwA6nQ7+/ve/AwB+//vfH+ef/41o6bqTpeUwT56O8ZOl1/tkfdPX+0yXqa2svLS0hHa7DQBIp9NIp9MAgnJ0MpkEACwuLqLX6/H2g4MDAEAul2NgPjw8BBAE5ZmZGQDA/v4++v3+Cf1HIiJyHqlMLSIiErEznRnPzc0BALa2tliOHgwGbMj64osv8MMf/hAA0Ov1mAEXi0XU63UAQRZsme/U1BTva3zfZ9YtIiLyJpzpYGzBs1KpMOgOBgOOAfu+z8BcqVQ4Dux5Hi5dugQAqFar6HQ6AIALFy4AADY3N/kcCwsLLIeLiIi8CSpTi4iIROxMp3yffvopgKAJ6z//+Q8AYDKZMOut1+vMnhuNBj777DMAwM2bN7G3twcAoaz3448/BgD0+33e3u/3sbCwcAL/jYiInFdnOhjbWO+NGzcQiwVJ/vz8PDulS6USdnd3AQB37tzB+vo6AGA0GjHY9no9jj2/fPmSj7G0tAQAmJ6exr17907oPxIRkfNIZWoREZGInenM2ErQCwsLnE88Go3YNe37Pr788ksAwN7eHorFIgDgyZMneO+99wAAz58/Z4n79u3bAIBms8mFQz7//HNUKpWT+YdERORY3L17F0AwLFkulwEEw5ij0YjfW3V1ZmaGSyTb116vx1k6c3NzjAlra2u8fTgc4h//+AcAcBj0dSkzFhERidiZzoxLpRKAIANuNBoAgo0dbDx4bm6O84nX19eRzWYBAJlMhg1c5XIZmUwGADi+vLS0xDHoTCbD1bhERORssGmthUKBmwfFYjGe73u9HjcPSqfTzJJ3dnYAgNk0ALRaLeRyOT6uLZ3c6/WYaX9XZzoYWzkhFosxGI9GI74JzWYT09PTAICrV6+ytDA1NcWdmJaWlvhG2eM9ffqU61g3Gg0+hoiInA2pVApAMIxp5/v5+Xku6tTpdBiYB4MBf8+ad+PxOIN1v99nfOh2uxwi9X2fz/NdqUwtIiISsTOdGVsGfHBwwCuV3d1dlqlnZmbY2FUoFFiyLhQKvMqZnZ3llCcrQ9RqNZaxV1dXQ+UKERE5/WyosdlsMj40m02u0JjNZtnsO5lMGCtqtRofw2LM9PQ0Y0a73eaqjY1GQ2VqICgRAMD//vc//PjHPwYQlBOsrOz7PssQT58+xZ07dwAEHdQ23hyPx7moh5Ubbt68iYcPHwIIdm3a398/of9IRESOgw0/DodDBtrhcMiSdDabZcDu9/u8j8WPWq3Gn8diMcabbrfL+8zNzalMLSIi8rY405nx4uIigOAK5tGjRwCCsrMN1ufzec4RPjg4YFd0NpvFs2fPAAQD
+nZFZB1yg8EA77zzDoAg67bmMBERORusxDwej0OlabdZy77PZDKsopr5+Xme+0ulEkvT2WyWWXc8Hn/l917XmQ7GVu+/dOkSg+v09DQD6dTUFO+Tz+f5wrZaLT7GaDRiMP7iiy8ABJO37QV2O+pERORssPP8eDxmWblarXJYcjQasbM6n89jY2MDQJDQAcG5337eaDSYzFUqFQbxVCrFseTvSmVqERGRiJ3pzPjw8BBA0AW9trYGAMxyAeCf//wnfvKTnwAAHj58iJs3bwIIrmysbJHJZJgx2zzk+fl5TgBfWFjgkpoiInI2WDbszqpJJBLsrO52u7y9UqmwK9oadsvlMve4HwwGnHWzsrLCpTF93z+2zPhMB2MrIczOzoZWR7HbC4UCA+0vfvELliFyuRyDbb1eZznDtl5stVpIp9N8HuuoEzlOxWKRQyrxeJxdmY1Gg2NSdlwPBgMOuXiexxNHOp3m8e7enkgkeMFZrVZ1QSnnjgXJwWDAAJxOp9FsNgEEnxe7Tz6f5+599vN6vc79DIbDIYP7ZDLh5zOVSqmbWkRE5G1xpjNjKyt3u12WE9y1Qi9dusT77OzshJbPtCXPut0uMw4blJ+bm2O2kcvlQqVvkeOyvLyMn//85wCCbNgy2WQyyeNvfn4eQHgd3d3dXTYYFgoFbG1tAQiGXOwYd3eZuX//vjJjiZzbFGvVRjvm7SsQlH6tE9rzPP7Mnesbi8V4rHuex/vbfbvdLu+bTCb5vZvVptNpLhDV7/dZIbVMN51O8/eAoyrVzs4OK6eJRCJ0n+/iTAdjC6jum1GtVlluGI/HXHWr0WjwBZydneWmEJ7n8Xa7byqVYnnCShYixy0ej/NCcH9/nz0Nvu+j3W4DAIdZyuUyLyxTqRQvHA8PDxmk3ZNYtVplWc26Q0Wi9Nvf/hZAcKHobrQABOtEu8M09n0mk+GUomKxGDpH22ckmUzy/nbbn//8Zz721tYWLl++zOdzV190n9P+Jku+fN/ncOZ4PA7NqrEAPBqNjm0FLpWpRUREInamM2O70kokEiw39Ho9Xu3EYjFmE8DRFU8sFmNm4fs+f9ctddhjjEaj0FqlIsfFXeM2n8/zOKvVaq9c6afTaR6/U1NT/N7zPJajXYlEgtWfWq2G69evAziamz8ej0ObrFu24H6W3LLg15XiPM8LZRH2+el0Onj58uVrvirytnKHXKwkbNvaAkdrQq+urjIDdvX7fZaKt7e3+XjA1x+f9tjlchmbm5sAws1crVaLx2wikeBnzTLgfD7Pz4jbMT07O8uK6WQyObYG3zMdjB8/fhz1nyDy2jzP4/jV1NQUT1DlcpnDJFamjsViDJKDwYAXme6JIJfLsTSXSCRYni6VSvjd734HAPjNb34DIFyu8zyPY8qLi4uhmQl2UVCr1VjGs/uWSiWe8FqtFm7fvg0AuHfvHteKFzEW2HK5XChJAoJgasH4xo0bPI5nZmYYHJPJJI/BfD6PixcvAgguMC042sXh/fv3Oa313XffZbl5ZmaG37sLQdlnyx4PCI5pW+Xx6dOn3Oe4UCgwoFerVf6t35XK1CIiIhE705mxyNui1+uFysN2xW6Zs1vOy+fzeP78Ob+3BsRut8vhl36/z6v3yWTCq33LINyyXr/fZ8nvwoUL/DtyuRz++9//AgiyhYODAwBHa8K7Xd2xWAz//ve/AQTLyYp8lR17rVaLFRc7Rufn53lbp9NhdSabzbKU3G63Wb5+//33OfQCILSIExBUTe/evQsA+OlPf8oKUrfbDX3O3GEWy9bts+b7PpfAXFxcxF//+lf+HZbF5/N5zTMWERF5WygzFomI2zjlbl7ueR4zWWvw6vf7oSYrWy0OQGgXGruPu4D94uIim7zsqr/T6XBcut1uMxPvdDqcMri5ucn7p1KpUHYOBONzlrkvLCy88hwiLsuMb9y4waWM3bUiLOtNJBKh3ge3mcuOsTt37nAseWtri/0RVvFZWFgIrdBo49Wl
UolZ7WQyYUY9Pz/PHf5sSuDBwQE/F6PRiPeNx+Oh51MDl8gZl0gk+EEeDAY8SZRKJd7HSmDJZJKNJy9evMDKygrvYyeg4XDIkrW7ofr29jZLdqZUKvFE2Ol02Ekai8V4MpqammJAH4/HvEBwdz+zny8tLTGI2+5nIi47lg8PD3nM2rE+Go1YEm40GrxQnJ2dDTVULS8vAwiGU2yxm/39/Vfm1KfTaX62ZmZmQg1e9rloNBqhwGtrUltZPBaLhWY72EXB4uIiA7Pnece26IfK1CIiIhFTZiwSkX6/H9rUxLKE8XjMLNkaXJrNJq/AO50Of55IJJgxd7tdlMtlAEG2bI89GAxemZbk7jyTSqXY1HLx4kXuDT6ZTPD9738fAPDJJ5/g2rVrAMJzla1E/vDhQ658ZxmOiMuOm0wmw+Nmb28PQLAqlx3fjUaDQx3r6+usuKTTaT7G5uYm3nvvPQDBZ8SyWcuou90uh3JisVhodS2bc+xuGuH7Plfpss9TPB5n9u37PqtO6+vruHXrFoBgyt9xrcClYCwSkUwmw5MHcFRuTiQS7Ca1k8/s7CxPEqlUigt6ZDIZluY8z+Nc5Uqlws7S6elpnlR+8IMfAAhOZhZo3R2h9vf38b3vfQ9AcOKyoH/lyhX+fTa/s9vthsbirMxnJzURl9u5bN+7607bMegG5kwmw2MwFosxIF66dIll46WlJR57Ftzj8Tg++ugjAMGx/uDBAwBBedsuBNyu7Vwux3FsuwDudDq8KNjY2MBf/vKXY35FwlSmFhERiZgyY5GIdDodlpKnp6fZRDUej1l2syv3er3O5pVkMslS8GAwYKNKr9djw8mNGzeYafR6PTZrWSaytLTEecOWQRvLMizLBoJs127/17/+xb95YWEBQNAktra2BgBaClO+lh1n2WyWx6E1dXmeF5rfbkajUWi4xTqo3eUwK5UK577bY7grxm1vb3NOslvNcf8Od4cme75MJsPPjfs3vSkKxiIRSaVSnBrkbis3MzPDkpm7HOb29jaAYBqGnTAmkwlPcsvLy+xurlarLF/H43GW3iyIj0YjLu9XLBYZ6O1nQHgdX+Coy9RObAcHB6ExN/u9ry51eJr8+te/BhCUQt31hr9ufW878bdaLXbdep7H92FxcZEn/GazyfchkUjgD3/4w5v7J84oKzdns1ke93Z8x+Nx9kz0+/3QdCF3LWl7z3Z3d7G6ugogCJpWnrZx4lQqxds6nQ4vFN1jenp6mhekm5ubfP6nT58CALflPSkqU4uIiERMmbFIRNz9jPf29r527qJlwPl8nrf1+32WjEulEjPSVqsVKsFZ9tFqtZgRWyl8eXmZTTSpVIoLHvR6PWaEiUSCj/HgwQNmHdbU0u/3mS0fHh6GFm04rT744AMAwN27d9nI02g0mO3a65ROp5kZbWxs8P9053vncrnQDnHuBvbKjF9lx/fOzg6zWneRGjvWSqUSGw7L5TIz1k6nw/cslUrhnXfeARC8P1ahcBepsTJ2o9Fgw2OhUGAFYzQahZocrRplv3dwcMDyt/3sTTq9nxqRt5y7MEcul+NJvlAo8CRlJyh3ndxarcYStLtjjJ3sgCA4WHAcDAYsx1m51fd9XgjE43E+z9zcHANMLpdjkHbHzOyxKpVKaFqHuy3paWX/x7Nnz/h6zszM8H2wBUtu3brF9+PatWt8HXzf53szmUwYxNvtNt8Td2s/OWLH1crKCodI3J/Z8Ear1QoFbjvGfN/n7Wtra7xAun37NoPl559/DiDcSxGLxRjQ3SmCbld3u93m89hsgVgsxuc7iel6KlOLiIhETJmxSETcuZXj8ZiNJTs7Oyy3WfltPB7zSn91dZVl6slkwsYY3/f5e1tbW5zvu7W1xQYld71eK91NJhNmzNlsllnL8vIyy3TVapVlassSL168yIzafYzjWqv3TbD/+datW8zkt7a2mAVbdjQ3N8fXYX19nf/75cuXeftHH33EOdmVSoWZmi2sImHu
rk2WkbrNgvZ+5HI5HleZTIavK3BU1m40Gsykd3d3X2lQXF5e5vuQy+VCi4K4603b4/m+z2PAqh2ZTIafl5M4phWMRSKSSCR4kpiamuJYa6FQ4MnIDAYDnnz29vZYaovH4zyJlctllkrdxfXH4zGuXr0K4OikUq/XGeir1SrHkt1NJxqNBst7nueFyt5AcGKzRUs8zwv9faeVvW6xWIwdttVqleVI97V0O33dMujNmzcBBO/TkydP+L1duNiYu4TZxVoikeBFox0z7kYp8/PzvMAsFAqhFebs89LpdEKlbnfDB/u5lcVHo1HootE+Z+l0mo/XaDRCMweAoHRuUwWPa5Wt/8/pvYQVERE5J5QZi0TEbRAZDofMFnq9HrMB+zo9Pc0r+nw+H9og3dbu7XQ6LAVeuXKFJfBUKsWM2K7+d3d3mdVmMhlmdfV6ndnC9vZ2aMvFrzZouU1nqVSKzTNuI9lpY+VR4KhD1l3cxDK158+fY2NjA0BQjrdhAeCoVHrt2jWWvePxOB/nJBaIOIvcY9qyVnd4wDqbe70eX+Nms8lj2q0CZbNZHtOJRIKPZ++Zu767e5wmk0ne1x36KRaLoe+B4Di23/u6eejHTZmxiIhIxJQZi0SkUqng4cOHAILM0h0PswzOnVpki+S7Swe6Y13JZJLjnMVikWO8jUYDv/rVrwAcjSVfvnyZzSnz8/PcyWYwGPB50uk073PhwgV+b9nC4eEhx6gfPHjAv/k0zzO2jKvb7XI80PM8ZrM2rvnpp5/iRz/6EYAgO7PqQ71eD42JWzYXj8dZRbCKhIRZxrmyssKqhL3utVqNfQtTU1OsUHiex8/F9PQ0V4rL5/Ohfb+N9UHs7u7y/ajVaqHpdvYeDwYDHtOpVIrP7y5Fa5+FkzimT++nRuQtV6lU8OGHH57oc9rJyj0RdTodnoC63S63SrQTERB0CNuJ006k7gkqkUjw5Hia5xlbAC6Xy2wAGg6HvKCxDvRGoxHqQLdAMh6PeZIvl8v47LPPAAQB2E7ybilcjthF3JMnT0JDK0B4nvH9+/e5KEitVuNx2Gg0OASSzWb52nc6HS7QYl9XV1d5Iesue5lMJkNlb/dC1u3UBoJZDSc5M0BlahERkYgpMxY5B9wmJSDILCyT6/V6zBzm5uZC8y3n5uYAhLMIyz7a7TankvR6PZYQ19fX3/B/8/pWVlYABFmaO4fbsmD738bjMV+H9fV1TldqNpt4//33AQCPHj3iak0zMzOcPnMSqzWdRX/84x+P7bE+/vhjzvEGjhrv3M0drOkunU6HdnWyilCtVuNnADha7c7NxN2pUm+agrHIOWLl1l6vx67SyWTCAFKv19mRmslkGJB6vR7Hm23bRHf8dGVlhWVrdzek08YWgrh+/TrL8V9++SVPuu4axlZuX15e5sm+WCxyycypqSme8JvNJi9W3M5reTPcvopSqRQKsEBwbNqFEnB03D979oy35XK50K5oVgK3C69ut6sytYiIyHmizFjkHPjTn/4EIOh6BoJynXX99vt9dgW7yzpa5gwETVlW6rYO6vF4zIykUCgwi7Bs+jSyeaoHBwfM8Pv9PkuUbuneMuPZ2VlmvYlEgv+nuyJUrVYLraAmb5a7el29Xg/tNAYE2bLbHPbixQsAQVXHMuB+v8+mw1wux4qGVXbW1ta4e5RVTN7o//TGn0FEIvf48ePQ1/PKgurq6ioXKfF9/5ULiUwmwyUQu90uL1xarRaD+MzMDH8vk8nwJO6upSxvRqvVYrd0NptFuVwGcLRsZbFYZNf0cDhk53y73eZF5mAwCO1cZn0T9hjz8/O8SP26aVTHTWVqERGRiCkzFpFz48qVKwCCrPbevXsAgnKkNbBZdru1tcUMKhaLscHL930ukFIsFvn9ysoKH8MyNnlzHj9+/NZVeRSMReTcePnyJQBgc3MT169fBxCMKdr4sJUor169ik8++QRAMGZs44yTyYTfV6tVTv2qVCosc9o0GpFvQ2Vq
ERGRiCkzFpFzw5Zh3N3dZfNVKpXiHGrrxu10OlzcI5lMhtbjtsVCpqen2YHr7pl7mudZy+mlYCwi54YtaPKzn/2M3bbdbpdd1lam9n2f3+/u7nJxk6WlJW4osLy8zNWaDg8PuXb3V1c7E/kmVKYWERGJmDJjETk3bN5oOp3mcoru0opWjh4Oh+yU7vV6XAwklUqFtl58/vw5gKDj2pZctHK1yLehYCwi54ZtVvC3v/2NCz64wdM6pYfDIYOu7/ucthSLxUL7QNvKTOl0mgFbwVheh8rUIiIiEfPsSlBERESiocxYREQkYgrGIiIiEVMwFhERiZiCsYiISMQUjEVERCKmYCwiIhIxBWMREZGIKRiLiIhETMFYREQkYgrGIiIiEVMwFhERiZiCsYiISMQUjEVERCKmYCwiIhIxBWMREZGIKRiLiIhETMFYREQkYgrGIiIiEVMwFhERiZiCsYiISMQUjEVERCKmYCwiIhIxBWMREZGI/R8nQmNyVYqNZwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x114de2048>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pyro.set_rng_seed(121)\n",
    "show_images(prior(5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Specifying the likelihood\n",
    "\n",
    "The last thing we need in order to complete the specification of the model is a likelihood function. Following [1] we will use a Gaussian likelihood with a fixed standard deviation of 0.3. This is straightforward to implement with `pyro.sample` using the `obs` argument.\n",
    "\n",
    "When we later come to perform inference we will find it convenient to package the prior and likelihood into a single function. This is also a convenient place to introduce `plate`, which we use to implement data subsampling, and to register the networks we would like to optimize with `pyro.module`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(data):\n",
    "    # Register network for optimization.\n",
    "    pyro.module(\"decode\", decode)\n",
    "    with pyro.plate('data', data.size(0)) as indices:\n",
    "        batch = data[indices]\n",
    "        x = prior(batch.size(0)).view(-1, 50 * 50)\n",
    "        sd = (0.3 * torch.ones(1)).expand_as(x)\n",
    "        pyro.sample('obs', dist.Normal(x, sd).to_event(1),\n",
    "                    obs=batch)"
   ]
  },
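  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `obs` statement treats each pixel as an independent Gaussian centred on the rendered image, and `.to_event(1)` sums the 2500 per-pixel log densities into a single log likelihood per image. The sketch below reproduces that quantity with `torch.distributions` on made-up images (all `lik_*` names are hypothetical) in which every observed pixel lies exactly one standard deviation from the rendered value, and compares it with the closed-form Gaussian log density, 2500 * (-0.5 - log 0.3 - 0.5 * log(2 pi)), roughly -537.4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.distributions as tdist\n",
    "\n",
    "# Hypothetical rendered images (all zeros) and observed images in which every\n",
    "# pixel sits exactly one standard deviation (0.3) above the rendered value.\n",
    "lik_x = torch.zeros(2, 50 * 50)\n",
    "lik_batch = torch.full((2, 50 * 50), 0.3)\n",
    "lik_sd = 0.3\n",
    "\n",
    "# One log likelihood per image: the per-pixel Normal log densities summed\n",
    "# over the 2500 pixels, as .to_event(1) does in the model.\n",
    "lik_log_lik = tdist.Normal(lik_x, lik_sd).log_prob(lik_batch).sum(-1)\n",
    "\n",
    "# Closed form: each pixel contributes -0.5 * z**2 - log(sd) - 0.5 * log(2 * pi),\n",
    "# with z = (obs - mean) / sd = 1 here.\n",
    "lik_manual = 50 * 50 * (-0.5 - math.log(lik_sd) - 0.5 * math.log(2 * math.pi))"
   ]
  },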
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Guide\n",
    "\n",
    "Following [1] we will perform [amortized stochastic variational inference](svi_part_i.ipynb) in this model. Pyro provides general-purpose machinery that implements most of this inference strategy, but as we have seen in earlier tutorials we are required to provide a model-specific guide. What we call a guide in Pyro is exactly the entity called the \"inference network\" in the paper.\n",
    "\n",
    "We will structure the guide around a recurrent network to allow the guide to capture (some of) the dependencies we expect to be present in the true posterior. At each step the recurrent network will generate the parameters for the choices made within the step. The values sampled will be fed back into the recurrent network so that this information can be used when computing the parameters for the next step. The guide for the [Deep Markov Model](dmm.ipynb) shares a similar structure.\n",
    "\n",
    "As in the model, the core of the guide is the logic for a single step. Here's a sketch of an implementation of this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def guide_step_basic(t, data, prev):\n",
    "\n",
    "    # The RNN takes the images and choices from the previous step as input.\n",
    "    rnn_input = torch.cat((data, prev.z_where, prev.z_what, prev.z_pres), 1)\n",
    "    h, c = rnn(rnn_input, (prev.h, prev.c))\n",
    "\n",
    "    # Compute parameters for all choices made this step, by passing\n",
    "    # the RNN hidden state through another neural network.\n",
    "    z_pres_p, z_where_loc, z_where_scale, z_what_loc, z_what_scale = predict_basic(h)\n",
    "\n",
    "    z_pres = pyro.sample('z_pres_{}'.format(t),\n",
    "                         dist.Bernoulli(z_pres_p * prev.z_pres))\n",
    "\n",
    "    z_where = pyro.sample('z_where_{}'.format(t),\n",
    "                          dist.Normal(z_where_loc, z_where_scale))\n",
    "\n",
    "    z_what = pyro.sample('z_what_{}'.format(t),\n",
    "                         dist.Normal(z_what_loc, z_what_scale))\n",
    "\n",
    "    return # values for next step"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This would be a reasonable guide to use with this model, but the paper describes a crucial improvement we can make to the code above. Recall that the guide will output information about an object's pose and its latent code at each step. The improvement we can make is based on the observation that once we have inferred the pose of an object, we can do a better job of inferring its latent code if we use the pose information to crop the object from the input image, and pass the result (which we'll call a \"window\") through an additional network in order to compute the parameters of the latent code. We'll call this additional network the \"encoder\" below.\n",
    "\n",
    "Here's how we can implement this improved guide, and a fleshed out implementation of the networks involved:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input: 50 * 50 image + z_where (3) + z_what (50) + z_pres (1) = 2554 features.\n",
    "rnn = nn.LSTMCell(2554, 256)\n",
    "\n",
    "# Maps the pixel intensities of the attention window to the parameters\n",
    "# (mean, standard deviation) of the distribution over the latent code,\n",
    "# z_what.\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.l1 = nn.Linear(400, 200)\n",
    "        self.l2 = nn.Linear(200, 100)\n",
    "\n",
    "    def forward(self, data):\n",
    "        h = relu(self.l1(data))\n",
    "        a = self.l2(h)\n",
    "        return a[:, 0:50], softplus(a[:, 50:])\n",
    "\n",
    "encode = Encoder()\n",
    "\n",
    "# Maps the guide RNN hidden state to the parameters of\n",
    "# the guide distributions over z_where and z_pres.\n",
    "class Predict(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Predict, self).__init__()\n",
    "        self.l = nn.Linear(256, 7)\n",
    "\n",
    "    def forward(self, h):\n",
    "        a = self.l(h)\n",
    "        z_pres_p = sigmoid(a[:, 0:1]) # Squish to [0,1]\n",
    "        z_where_loc = a[:, 1:4]\n",
    "        z_where_scale = softplus(a[:, 4:]) # Squish to >0\n",
    "        return z_pres_p, z_where_loc, z_where_scale\n",
    "\n",
    "predict = Predict()\n",
    "\n",
    "def guide_step_improved(t, data, prev):\n",
    "\n",
    "    rnn_input = torch.cat((data, prev.z_where, prev.z_what, prev.z_pres), 1)\n",
    "    h, c = rnn(rnn_input, (prev.h, prev.c))\n",
    "    z_pres_p, z_where_loc, z_where_scale = predict(h)\n",
    "\n",
    "    z_pres = pyro.sample('z_pres_{}'.format(t),\n",
    "                         dist.Bernoulli(z_pres_p * prev.z_pres)\n",
    "                             .to_event(1))\n",
    "\n",
    "    z_where = pyro.sample('z_where_{}'.format(t),\n",
    "                          dist.Normal(z_where_loc, z_where_scale)\n",
    "                              .to_event(1))\n",
    "\n",
    "    # New. Crop a small window from the input.\n",
    "    x_att = image_to_object(z_where, data)\n",
    "\n",
    "    # Compute the parameters of the distribution over z_what\n",
    "    # by passing the window through the encoder network.\n",
    "    z_what_loc, z_what_scale = encode(x_att)\n",
    "\n",
    "    z_what = pyro.sample('z_what_{}'.format(t),\n",
    "                         dist.Normal(z_what_loc, z_what_scale)\n",
    "                             .to_event(1))\n",
    "\n",
    "    return # values for next step"
   ]
  },
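  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The layer sizes used above follow from the shapes of the quantities involved. The RNN input concatenates the flattened 50 x 50 image with the previous step's `z_where` (3 values), `z_what` (50 values) and `z_pres` (1 value), giving 2500 + 3 + 50 + 1 = 2554 features, and the 7 outputs of the `Predict` head split into 1 value for `z_pres_p` plus 3 each for the location and scale of `z_where`. The sketch below checks these shapes, and the ranges produced by the squashing functions, using freshly initialized stand-in layers on made-up inputs. The `shape_demo_*` names are hypothetical and not used elsewhere in this tutorial, and `torch.sigmoid` stands in for the `sigmoid` imported at the top of the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn.functional import softplus\n",
    "\n",
    "shape_demo_n = 4  # illustrative mini-batch size\n",
    "\n",
    "# Stand-ins for the guide RNN and the Predict head.\n",
    "shape_demo_rnn = nn.LSTMCell(50 * 50 + 3 + 50 + 1, 256)\n",
    "shape_demo_head = nn.Linear(256, 7)\n",
    "\n",
    "# Made-up images and previous-step values, concatenated as in the guide.\n",
    "shape_demo_input = torch.cat((torch.linspace(0.0, 1.0, 50 * 50).repeat(shape_demo_n, 1),  # images\n",
    "                              torch.zeros(shape_demo_n, 3),     # z_where\n",
    "                              torch.zeros(shape_demo_n, 50),    # z_what\n",
    "                              torch.ones(shape_demo_n, 1)), 1)  # z_pres\n",
    "shape_demo_h, shape_demo_c = shape_demo_rnn(\n",
    "    shape_demo_input,\n",
    "    (torch.zeros(shape_demo_n, 256), torch.zeros(shape_demo_n, 256)))\n",
    "\n",
    "# Split the 7 outputs the same way Predict does.\n",
    "shape_demo_a = shape_demo_head(shape_demo_h)\n",
    "shape_demo_z_pres_p = torch.sigmoid(shape_demo_a[:, 0:1])  # in (0, 1)\n",
    "shape_demo_z_where_loc = shape_demo_a[:, 1:4]               # unconstrained\n",
    "shape_demo_z_where_scale = softplus(shape_demo_a[:, 4:])    # positive"
   ]
  },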
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we would like to maintain differentiability of the guide, we again use an STN to perform the required \"cropping\". The `image_to_object` function performs the opposite transform to the `object_to_image` function used in the model. That is, `object_to_image` takes a small image and places it on a larger image, while `image_to_object` crops a small image from a larger image."
   ]
  },
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def z_where_inv(z_where):\n",
    "    # Take a batch of z_where vectors, and compute their \"inverse\".\n",
    "    # That is, for each row compute:\n",
    "    # [s,x,y] -> [1/s,-x/s,-y/s]\n",
    "    # These are the parameters required to perform the inverse of the\n",
    "    # spatial transform performed in the generative model.\n",
    "    n = z_where.size(0)\n",
    "    out = torch.cat((torch.ones([1, 1]).type_as(z_where).expand(n, 1), -z_where[:, 1:]), 1)\n",
    "    out = out / z_where[:, 0:1]\n",
    "    return out\n",
    "\n",
    "def image_to_object(z_where, image):\n",
    "    n = image.size(0)\n",
    "    theta_inv = expand_z_where(z_where_inv(z_where))\n",
    "    grid = affine_grid(theta_inv, torch.Size((n, 1, 20, 20)))\n",
    "    out = grid_sample(image.view(n, 1, 50, 50), grid)\n",
    "    return out.view(n, -1)"
   ]
  },
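  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check that the `[s,x,y] -> [1/s,-x/s,-y/s]` rule in `z_where_inv` really inverts the transform, we can apply the affine map described by a `z_where` vector to a few points and then apply the map described by its inverse. This sketch assumes that `expand_z_where` (defined earlier) lays `[s, x, y]` out as the 2 x 3 matrix `[[s, 0, x], [0, s, y]]`, so that a point p is mapped to s * p + (x, y). It re-implements both steps with plain tensors under hypothetical names (`affine_from_z_where`, `invert_z_where`, `apply_affine`) so that it runs on its own. Composing the two maps should return the original points, since (s * p + t) / s - t / s = p."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def affine_from_z_where(z_where):\n",
    "    # [s, x, y] -> 2 x 3 matrix [[s, 0, x], [0, s, y]] (assumed layout).\n",
    "    s, x, y = z_where[:, 0], z_where[:, 1], z_where[:, 2]\n",
    "    zeros = torch.zeros_like(s)\n",
    "    return torch.stack((torch.stack((s, zeros, x), 1),\n",
    "                        torch.stack((zeros, s, y), 1)), 1)\n",
    "\n",
    "def invert_z_where(z_where):\n",
    "    # [s, x, y] -> [1/s, -x/s, -y/s], the same rule as z_where_inv.\n",
    "    s = z_where[:, 0:1]\n",
    "    return torch.cat((torch.ones_like(s), -z_where[:, 1:]), 1) / s\n",
    "\n",
    "def apply_affine(theta, points):\n",
    "    # Map points of shape (n, k, 2) through theta of shape (n, 2, 3):\n",
    "    # A @ p + t, where A = theta[:, :, :2] and t = theta[:, :, 2].\n",
    "    return points.matmul(theta[:, :, :2].transpose(1, 2)) + theta[:, :, 2].unsqueeze(1)\n",
    "\n",
    "inv_demo_zw = torch.tensor([[0.5, 0.2, -0.3], [2.0, -1.0, 0.4]])\n",
    "inv_demo_pts = torch.tensor([[[0.0, 0.0], [1.0, -1.0]],\n",
    "                             [[0.3, 0.7], [-0.5, 0.25]]])\n",
    "\n",
    "# Forward map, then the map described by the inverted z_where.\n",
    "inv_demo_roundtrip = apply_affine(affine_from_z_where(invert_z_where(inv_demo_zw)),\n",
    "                                  apply_affine(affine_from_z_where(inv_demo_zw), inv_demo_pts))"
   ]
  },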
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Another perspective\n",
    "\n",
    "So far we've considered the model and the guide in isolation, but we gain an interesting perspective if we zoom out and look at the model and guide computation as a whole. Doing so, we see that at each step AIR includes a sub-computation that has the same structure as a [Variational Auto-encoder](vae.ipynb) (VAE).\n",
    "\n",
    "To see this, notice that the guide passes the window through a neural network (the encoder) to generate the parameters of the distribution over a latent code, and the model passes samples from this latent code distribution through another neural network (the decoder) to generate an output window. This structure is highlighted in the following figure, reproduced from [1]:"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/html"
   },
   "source": [
    "<center>\n",
    "<figure style='padding: 0 0 1em'>\n",
    "<img src='_static/img/model-micro.png' style='width: 35%;'>\n",
    "<figcaption style='font-size: 90%; padding: 0.5em 0 0'>\n",
    "<b>Figure 2:</b> Interaction between the guide and model at each step.\n",
    "</figcaption>\n",
    "</figure>\n",
    "</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From this perspective AIR is seen as a sequential variant of the VAE. The act of cropping a small window from the input image serves to restrict the attention of a VAE to a small region of the input image at each step; hence \"Attend, Infer, Repeat\".\n",
    "\n",
    "## Inference\n",
    "\n",
    "As we mentioned in the introduction, successfully performing inference in this model is a challenge. In particular, the presence of discrete choices in the model makes inference trickier than in a model in which all choices can be reparameterized. The underlying problem we face is that the gradient estimates we use in the optimization performed by variational inference have much higher variance in the presence of [non-reparameterizable choices](svi_part_iii.ipynb#Tricky-Case:-Non-reparameterizable-Random-Variables).\n",
    "\n",
    "To bring this variance under control, the paper applies a technique called \"data dependent baselines\" (AKA \"neural baselines\") to the discrete choices in the model.\n",
    "\n",
    "### Data dependent baselines\n",
    "\n",
    "Happily for us, Pyro includes support for data dependent baselines. If you are not already familiar with this idea, you might want to read [our introduction](svi_part_iii.ipynb#Baselines-in-Pyro) before continuing. As model authors we only have to implement the neural network, pass it our data as input, and feed its output to `pyro.sample`. Pyro's inference back-end will ensure that the baseline is included in the gradient estimator used for inference, and that the network parameters are updated appropriately.\n",
    "\n",
    "Let's see how we can add data dependent baselines to our AIR implementation. We need a neural network that can output a (scalar) baseline value at each discrete choice in the guide, having received a multi-mnist image and the values sampled by the guide so far as input. Notice that this is very similar to the structure of the guide network, and indeed we will again use a recurrent network.\n",
    "\n",
    "To implement this we will first write a short helper function that implements a single step of the RNN we've just described:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "bl_rnn = nn.LSTMCell(2554, 256)\n",
    "bl_predict = nn.Linear(256, 1)\n",
    "\n",
    "# Use an RNN to compute the baseline value. This network takes the\n",
    "# input images and the values sampled so far as input.\n",
    "def baseline_step(x, prev):\n",
    "    rnn_input = torch.cat((x,\n",
    "                           prev.z_where.detach(),\n",
    "                           prev.z_what.detach(),\n",
    "                           prev.z_pres.detach()), 1)\n",
    "    bl_h, bl_c = bl_rnn(rnn_input, (prev.bl_h, prev.bl_c))\n",
    "    bl_value = bl_predict(bl_h) * prev.z_pres\n",
    "    return bl_value, bl_h, bl_c"
   ]
  },
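  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before looking at the details of `baseline_step`, it may help to see why a baseline reduces variance at all. For a choice z ~ Bernoulli(p) whose probability is parameterized by a logit, the score-function (REINFORCE) estimate of the gradient of E[f(z)] with respect to that logit is (f(z) - b) * (z - p), where b is the baseline. Because E[z - p] = 0, any b that does not depend on z leaves the estimate unbiased, but it can change the variance a great deal. The sketch below computes the exact mean and variance by enumerating both outcomes, for an illustrative cost f(z) = 10 + z with p = 0.3, once without a baseline and once with the constant baseline b = 10. It uses a fixed constant rather than a learned, data dependent baseline, and the function names are hypothetical. Both estimates have mean 0.21 = p(1 - p), but the variance drops from about 24.04 to about 0.10."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_function_stats(f, p, baseline):\n",
    "    # Exact mean and variance of the estimate (f(z) - baseline) * (z - p)\n",
    "    # for z ~ Bernoulli(p); (z - p) is the derivative of log p(z) w.r.t. the logit.\n",
    "    outcomes = [(1, p), (0, 1.0 - p)]\n",
    "    estimates = [((f(z) - baseline) * (z - p), prob) for z, prob in outcomes]\n",
    "    mean = sum([g * prob for g, prob in estimates])\n",
    "    var = sum([(g - mean) ** 2 * prob for g, prob in estimates])\n",
    "    return mean, var\n",
    "\n",
    "def baseline_demo_cost(z):\n",
    "    # Illustrative cost with a large constant offset.\n",
    "    return 10.0 + z\n",
    "\n",
    "# No baseline versus a constant baseline b = 10.\n",
    "mean_no_bl, var_no_bl = score_function_stats(baseline_demo_cost, 0.3, baseline=0.0)\n",
    "mean_bl, var_bl = score_function_stats(baseline_demo_cost, 0.3, baseline=10.0)"
   ]
  },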
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are two important details to highlight in `baseline_step`:\n",
    "\n",
    "First, we `detach` values sampled by the guide before passing them to the baseline network. This is important because the baseline network and the guide network are entirely separate networks optimized with different objectives. Without it, gradients would flow from the baseline network into the guide network. When using data dependent baselines we must do this whenever we feed values sampled by the guide into the baseline network. (If we don't, we'll trigger a PyTorch run-time error.)\n",
    "\n",
    "Second, we multiply the output of the baseline network by the value of `z_pres` from the previous step. This relieves the baseline network of the burden of having to output accurate predictions for completed samples. (The outputs for completed samples will be multiplied by zero, so the derivative of the [baseline loss](svi_part_iii.ipynb#Neural-Baselines) for these outputs will be zero.) It's OK to do this because in effect we've already removed random choices for completed samples from the inference objective, so there's no need to apply any variance reduction to them.\n",
    "\n",
    "We now have everything we need to complete the implementation of the guide. Our final `guide_step` function will be very similar to `guide_step_improved` introduced above. The only changes are:\n",
    "\n",
    "1. We now call the `baseline_step` helper and pass the baseline value it returns to `pyro.sample`.\n",
    "2. We now mask out the `z_where` and `z_what` choices for complete samples. This serves exactly the same purpose as the masks added to the model. (See the earlier discussion for the motivation behind this change.)\n",
    "\n",
    "We'll also write a `guide` function that will iterate `guide_step` in order to provide a guide for the whole model."
   ]
  },
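  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before assembling the guide, the two details of `baseline_step` discussed above can be checked in isolation with plain PyTorch. In this hypothetical sketch, `bl_demo_guide_param` stands in for a guide parameter and `bl_demo_weight` for the baseline network. The guide's values are detached before they reach the baseline, and the baseline output is multiplied by the previous `z_pres`; a squared-error loss against an arbitrary target stands in for the baseline loss. After `backward()`, the guide parameter has no gradient at all (its `.grad` is `None`), and the baseline weight's gradient comes only from the active sample: 2 * (1 - 5) * 2 = -16."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Stand-ins for a guide parameter and a baseline network weight.\n",
    "bl_demo_guide_param = torch.tensor(2.0, requires_grad=True)\n",
    "bl_demo_weight = torch.tensor(0.5, requires_grad=True)\n",
    "\n",
    "# Values 'sampled' by the guide depend on the guide parameter.\n",
    "bl_demo_z = bl_demo_guide_param * torch.tensor([1.0, 3.0])\n",
    "# The second sample is already complete.\n",
    "bl_demo_prev_z_pres = torch.tensor([1.0, 0.0])\n",
    "\n",
    "# Baseline: detach the guide's values, then multiply by the previous z_pres.\n",
    "bl_demo_value = bl_demo_weight * bl_demo_z.detach() * bl_demo_prev_z_pres\n",
    "\n",
    "# Squared-error loss against an arbitrary target, standing in for the baseline loss.\n",
    "bl_demo_target = torch.tensor([5.0, 5.0])\n",
    "bl_demo_loss = ((bl_demo_value - bl_demo_target) ** 2).sum()\n",
    "bl_demo_loss.backward()"
   ]
  },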
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "GuideState = namedtuple('GuideState', ['h', 'c', 'bl_h', 'bl_c', 'z_pres', 'z_where', 'z_what'])\n",
    "def initial_guide_state(n):\n",
    "    return GuideState(h=torch.zeros(n, 256),\n",
    "                      c=torch.zeros(n, 256),\n",
    "                      bl_h=torch.zeros(n, 256),\n",
    "                      bl_c=torch.zeros(n, 256),\n",
    "                      z_pres=torch.ones(n, 1),\n",
    "                      z_where=torch.zeros(n, 3),\n",
    "                      z_what=torch.zeros(n, 50))\n",
    "\n",
    "def guide_step(t, data, prev):\n",
    "\n",
    "    rnn_input = torch.cat((data, prev.z_where, prev.z_what, prev.z_pres), 1)\n",
    "    h, c = rnn(rnn_input, (prev.h, prev.c))\n",
    "    z_pres_p, z_where_loc, z_where_scale = predict(h)\n",
    "\n",
    "    # Here we compute the baseline value, and pass it to sample.\n",
    "    baseline_value, bl_h, bl_c = baseline_step(data, prev)\n",
    "    z_pres = pyro.sample('z_pres_{}'.format(t),\n",
    "                         dist.Bernoulli(z_pres_p * prev.z_pres)\n",
    "                             .to_event(1),\n",
    "                         infer=dict(baseline=dict(baseline_value=baseline_value.squeeze(-1))))\n",
    "\n",
    "    z_where = pyro.sample('z_where_{}'.format(t),\n",
    "                          dist.Normal(z_where_loc, z_where_scale)\n",
    "                              .mask(z_pres)\n",
    "                              .to_event(1))\n",
    "    \n",
    "    x_att = image_to_object(z_where, data)\n",
    "\n",
    "    z_what_loc, z_what_scale = encode(x_att)\n",
    "\n",
    "    z_what = pyro.sample('z_what_{}'.format(t),\n",
    "                         dist.Normal(z_what_loc, z_what_scale)\n",
    "                             .mask(z_pres)\n",
    "                             .to_event(1))\n",
    "\n",
    "    return GuideState(h=h, c=c, bl_h=bl_h, bl_c=bl_c, z_pres=z_pres, z_where=z_where, z_what=z_what)\n",
    "\n",
    "def guide(data):\n",
    "    # Register networks for optimization.\n",
    "    pyro.module('rnn', rnn)\n",
    "    pyro.module('predict', predict)\n",
    "    pyro.module('encode', encode)\n",
    "    pyro.module('bl_rnn', bl_rnn)\n",
    "    pyro.module('bl_predict', bl_predict)\n",
    "\n",
    "    with pyro.plate('data', data.size(0), subsample_size=64) as indices:\n",
    "        batch = data[indices]\n",
    "        state = initial_guide_state(batch.size(0))\n",
    "        steps = []\n",
    "        for t in range(3):\n",
    "            state = guide_step(t, batch, state)\n",
    "            steps.append(state)\n",
    "        return steps"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Putting it all together\n",
    "\n",
    "We have now completed the implementation of the model and the guide. As we have seen in earlier tutorials, we need to write only a few more lines of code to begin performing inference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i=0, elbo=2806.79\n",
      "i=1, elbo=3656.81\n",
      "i=2, elbo=3222.37\n",
      "i=3, elbo=3872.77\n",
      "i=4, elbo=2818.27\n"
     ]
    }
   ],
   "source": [
    "data = mnist.view(-1, 50 * 50)\n",
    "\n",
    "svi = SVI(model,\n",
    "          guide,\n",
    "          optim.Adam({'lr': 1e-4}),\n",
    "          loss=TraceGraph_ELBO())\n",
    "\n",
    "for i in range(5):\n",
    "    loss = svi.step(data)\n",
    "    print('i={}, elbo={:.2f}'.format(i, loss / data.size(0)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One key detail here is that we use a `TraceGraph_ELBO` loss rather than a simpler `Trace_ELBO`. This indicates that we wish to use the gradient estimator that supports data-dependent baselines. This estimator also [reduces the variance](svi_part_iii.ipynb#Reducing-Variance-via-Dependency-Structure) of gradient estimates by making use of independence information included in the model. Something similar is implicitly used in [1], and is necessary in order to achieve good results on this model.\n",
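    "\n",
    "Why a baseline helps can be seen outside Pyro with a toy score-function (REINFORCE) estimator. The plain PyTorch sketch below is not how `TraceGraph_ELBO` is implemented; the Bernoulli parameter, the cost function `cost`, and the constant baseline `b` are all made up for illustration. It forms the per-sample gradient estimate `(cost(z) - b) * d/dp log q(z)` with and without a baseline: both versions have the same expectation (the true gradient, 2.0 here), but the baselined one has much lower variance.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Toy guide: a single Bernoulli with success probability p.\n",
    "p = torch.tensor(0.3)\n",
    "\n",
    "def cost(z):\n",
    "    # Hypothetical downstream cost; d/dp E[cost(z)] = 2.0.\n",
    "    return 10.0 + 2.0 * z\n",
    "\n",
    "n = 100000\n",
    "z = torch.bernoulli(p.expand(n))\n",
    "# d/dp log q(z) for a Bernoulli: z / p - (1 - z) / (1 - p).\n",
    "score = z / p - (1 - z) / (1 - p)\n",
    "\n",
    "b = 10.6  # hypothetical baseline, equal to E[cost(z)] = 10 + 2 * 0.3\n",
    "no_baseline = cost(z) * score\n",
    "with_baseline = (cost(z) - b) * score\n",
    "\n",
    "print(no_baseline.mean().item(), with_baseline.mean().item())  # both near 2.0\n",
    "print(no_baseline.var().item(), with_baseline.var().item())    # second is far smaller\n",
    "```\n",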
    "\n",
    "## Results\n",
    "\n",
    "To sanity check our implementation we ran inference using our [standalone implementation](https://github.com/uber/pyro/tree/dev/examples/air) and compared its performance against some of the results reported in [1].\n",
    "\n",
    "Here we show progress made on the ELBO and training set count accuracy during optimization:"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/html"
   },
   "source": [
    "<center>\n",
    "<figure style='padding: 0 0 1em'>\n",
    "<div style='width: 50%; float: left;'><img src=\"_static/img/air/progress_elbo.png\" /></div>\n",
    "<div style='width: 50%; float: left;'><img src=\"_static/img/air/progress_accuracy.png\" /></div>\n",
    "<figcaption style='font-size: 90%; clear: both;'><b>Figure 3:</b> <i>Left:</i> Progress on the evidence lower bound (ELBO) during optimization. <i>Right:</i> Progress on training set count accuracy during optimization.</figcaption>\n",
    "</figure>\n",
    "</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Count accuracy reached around 98.7%, which is in the same ballpark as the count accuracy reported in [1]. The value reached on the ELBO differs a little from that reported in [1], which may be due to small differences in the priors used.\n",
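    "\n",
    "We take count accuracy to mean the fraction of images for which the number of steps the guide turned on matches the true number of digits. A minimal sketch of that computation, assuming the guide's per-step `z_pres` samples have been stacked into a `(batch, steps)` tensor of zeros and ones (the tensors below are made up, and this is not necessarily how the standalone implementation computes it):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Hypothetical z_pres samples from a 3-step guide for 4 images.\n",
    "# Once a step is off, later steps are off too, so each row is 1s then 0s.\n",
    "z_pres = torch.tensor([[1., 1., 0.],\n",
    "                       [1., 0., 0.],\n",
    "                       [0., 0., 0.],\n",
    "                       [1., 1., 0.]])\n",
    "# Hypothetical true digit counts for the same images.\n",
    "true_counts = torch.tensor([2, 1, 1, 2])\n",
    "\n",
    "inferred_counts = z_pres.sum(dim=1).long()\n",
    "count_accuracy = (inferred_counts == true_counts).float().mean().item()\n",
    "print(count_accuracy)  # 0.75\n",
    "```\n",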
    "\n",
    "In the next figure the top row shows ten data points from the test set. The bottom row visualizes a single sample from the guide for each of these inputs, showing the values sampled for `z_pres` and `z_where`. Following [1], the first, second and third steps are displayed using red, green and blue borders respectively. (No blue borders are shown, as the guide did not use three steps for any of these samples.) The bottom row also shows reconstructions of the inputs, obtained by passing the latent variables sampled from the guide back through the model to generate an output image."
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/html"
   },
   "source": [
    "<center>\n",
    "<figure style='padding: 0 0 1em'>\n",
    "<img src=\"_static/img/air/reconstructions.png\" />\n",
    "<figcaption style='font-size: 90%; padding: 0.5em 0 0'><b>Figure 4:</b> <i>Top row:</i> Data points from the multi-mnist test set. <i>Bottom row:</i> Visualization of samples from the guide and the model's reconstruction of the inputs.</figcaption>\n",
    "</figure>\n",
    "</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These results were collected using the following parameters:\n",
    "\n",
    "```\n",
    "python main.py -n 200000 -blr 0.1 --z-pres-prior 0.01 --scale-prior-sd 0.2 --predict-net 200 --bl-predict-net 200 --decoder-output-use-sigmoid --decoder-output-bias -2 --seed 287710\n",
    "```\n",
    "\n",
    "We used Pyro commit `c0b38ad` with PyTorch `0.2.0.post4`. Inference ran for approximately 4 hours on an NVIDIA K80 GPU. (Note that even though we set the random seed, this isn't sufficient to make inference deterministic when using CUDA.)\n",
    "\n",
    "## In practice\n",
    "\n",
    "We found it important to pay attention to the following details in order to achieve good results with AIR.\n",
    "\n",
    "* Inference is unlikely to recover correct object counts unless a small prior success probability for `z_pres` is used. In [1] this [probability was annealed](http://akosiorek.github.io/ml/2017/09/03/implementing-air.html) from a value close to one to `1e-5` (or less) during optimization, though we found that a fixed value of around `0.01` worked well with our implementation.\n",
    "* We initialize the decoder network to generate mostly empty objects. (This is done using the `--decoder-output-bias` argument.) This encourages the guide to explore the use of objects to explain the input early in optimization. Without this, each object starts out as a mid-gray square, which is heavily penalized by the likelihood, prompting the guide to turn most steps off.\n",
    "* It is reported to be useful in practice to use a different learning rate for the baseline network. This is straightforward to implement in Pyro by tagging modules associated with the baseline network and passing multiple learning rates to the optimizer. (See the section on [optimizers](svi_part_i.ipynb#Optimizers) in part I of the SVI tutorial for more detail.) In [1] a learning rate of `1e-4` was used for the guide network, and a learning rate of `1e-3` was used for the baseline network. We found it necessary to use a larger learning rate for the baseline network in order to make progress on count accuracy at a similar rate to [1]. This difference is likely caused by Pyro setting up a [slightly different baseline loss](https://github.com/uber/pyro/issues/555).\n",
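    "\n",
    "The per-module learning rates from the last point above could be sketched as follows, assuming the baseline networks are registered under the names `'bl_rnn'` and `'bl_predict'` as in the `guide` function. Pyro's optimizers accept, in place of a fixed dictionary, a callable that maps a module name and parameter name to optimizer arguments; the function name `per_param_args` and the set `BASELINE_MODULES` are our own, and the learning rates are the ones reported for [1]:\n",
    "\n",
    "```python\n",
    "# Names under which the baseline networks are registered with pyro.module in guide().\n",
    "BASELINE_MODULES = {'bl_rnn', 'bl_predict'}\n",
    "\n",
    "def per_param_args(module_name, param_name):\n",
    "    # Larger learning rate for the baseline networks, as reported for [1].\n",
    "    if module_name in BASELINE_MODULES:\n",
    "        return {'lr': 1e-3}\n",
    "    # All other parameters.\n",
    "    return {'lr': 1e-4}\n",
    "\n",
    "print(per_param_args('bl_rnn', 'weight_ih'))  # {'lr': 0.001}\n",
    "print(per_param_args('rnn', 'weight_ih'))     # {'lr': 0.0001}\n",
    "```\n",
    "\n",
    "This callable would then be passed to the optimizer, e.g. `optim.Adam(per_param_args)`, instead of `optim.Adam({'lr': 1e-4})`.\n",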
    "\n",
    "\n",
    "## References\n",
    "\n",
    "[1] `Attend, Infer, Repeat: Fast Scene Understanding with Generative Models`\n",
    "<br />&nbsp;&nbsp;&nbsp;&nbsp;\n",
    "S. M. Ali Eslami and Nicolas Heess and Theophane Weber and Yuval Tassa and Koray Kavukcuoglu and Geoffrey E. Hinton\n",
    "\n",
    "[2] `Spatial Transformer Networks`\n",
    "<br />&nbsp;&nbsp;&nbsp;&nbsp;\n",
    "Max Jaderberg and Karen Simonyan and Andrew Zisserman"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
